CV Week: Итоговое задание¶

На лекции и семинаре мы разбирали как дистиллировать многошаговую диффузионную модель в малошагового студента, и тем самым будет работать на порядок быстрее учителя.

Один из подходов, который мы разбирали Consistency Distillation. В этом задании, мы закрепим материал, который был на лекции и семинаре и реализуем этот фреймворк, затрагивая различные нюансы.

В этом задании мы будем дистиллировать модель Stable Diffusion 1.5 (SD1.5) для генерации картинок по текстовому описанию.

Вам предстоит выполнить 8 небольших заданий, которые приведут нас к неплохой модели для генерации картинок за 4 шага, работая в органиченных условиях колаба.

In [ ]:
# torch 2.4.1+cu124
!pip install diffusers==0.30.3 peft==0.8.2 huggingface_hub==0.23.4

Теормин¶


Диффузионные модели¶

Задан прямой диффузионный процесс, который переводит чистые картинки в шум с помощью распределения $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$

Таким образом, мы можем получаться зашумленные картинки по следующей формуле: $\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$ (1)

$\alpha_t, \sigma_t$ задают процесс зашумления. Здесь мы будем иметь дело с variance preserving (VP) процессом, т. е., $\alpha^2_t = 1 - \sigma^2_t$.

Диффузионная модель (ДМ) пытается решить обратную задачу: из шума порождать новые картинки. Важно, что диффузионный процесс можно описать следующим обыкновенным дифференциальным уравнением (ОДУ):

$dx = \left[ f(\mathbf{x}, t) - \frac{1}{2} \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}) \right] dt$, (2)

где $f(\mathbf{x}, t)$ известен из заданного процесса зашумления, а $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$ (скор функцию) оцениваем с помощью нейросети: $s_\theta(\mathbf{x}_t, t) \approx \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$. Таким образом, имея оценку на $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x})$, мы можем решить это ОДУ, стартуя со случайного шума, и получить картинку.

SD1.5 использует $\epsilon$-параметризацию, т.е., UNet пытается предсказать шум, который мы добавили на картинку по формуле (1). Оценку скор функции можно получить, пользуясь результатом, вытекающим из формулы Твидди: $s_\theta(\mathbf{x}_t, t) = - \frac{\epsilon_\theta(\mathbf{x}_t, t)} { \sigma_t}$

Чтобы решить ОДУ (2), нам нужно воспользоваться каким-то численным методом (солвером). В этом задании мы будем работать с не самым эффектным, но самым популярным солвером: DDIM, который является адаптированным методом Эйлера под диффузионный ОДУ.

Для VP процесса переход с помощью DDIM с шага $t$ на $s$ можно сделать следующим образом:

$ x_s = DDIM(\mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $

Этот переход можно интерпретировать так: получаем оценку на чистую картинку $\mathbf{x}_0$ на шаге $t$, используя $\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t}$, а потом снова зашумляем эту оценку на шаг $s$ по формуле (1), но только используем не случайный шум, а шум предсказанный моделью $\epsilon_\theta$.

Используя DDIM для SD1.5, можем получать хорошие картинки за 50 шагов.

SD1.5 - латентная ДМ, т.е. модель работает не в пиксельном пространстве, а в латентном пространстве VAE. Таким образом SD1.5 состоит из следующих компонент:

  1. VAE - переводит $3{\times}512{\times}512$ картинки в латенты $4{\times}64{\times}64$ и может декодировать их обратно в картинки.
  2. Текстовый энкодер - извлекает текстовые признаки из промпта. Эти признаки будут подаваться в диффузионную модель, чтобы дать модели информацию, что именно хотим сгенерировать
  3. Диффузионная модель - UNet, работающий на "латентных картинках" $4{\times}64{\times}64$.

Консистенси модели¶

Общая идея¶

Главная цель дистилляции диффузии - уменьшить количество шагов ДМ, при этом сохранив высокое качество картинок.

Консистенси модели (Consistency Models | CM) - класс моделей, где мы хотим выучить "консистенси функцию" $f_\theta(\mathbf{x}_t)$ - с любой точки $\mathbf{x}_{t}$ траектории диффузионного ОДУ (2) сразу предсказывать $\mathbf{x}_{0}$ (чистые данные) за один шаг. Если мы идеально выучим консистенси функцию, то сможем шагать из чистого шума сразу в картинку, что супер эффективно в отличии от генерации ДМ.

Отметим, что консистенси модель можно учить как независимую генеративную модель, без предобученной ДМ, и в задании 3 вам предстоит подумать, как это можно сделать.


No description has been provided for this image

Консистенси дистилляция (Consistency Distillation | CD) - подход, когда для обучения CM, мы используем предобученную ДМ. ДМ нам дает качественную инициализацию модели и уже обученную скор функцию, что сильно упрощает сходимость консистенси моделей.

Обучение CM¶

No description has been provided for this image

Главная принцип обучения консистенси моделей заключается в попытке удовлетворить self-consistency св-ву: выход CM на двух соседних точках траектории $\mathbf{x}_{t}$ и $\mathbf{x}_{t-1}$ должен совпадать по какой-то мере близости, например L2 расстояние: $\lVert f_\theta(\mathbf{x}_{t-1}) - f_\theta(\mathbf{x}_{t}) \rVert^2_2$.

Заметим, что self-consistency св-во удовлетворить очень просто без какого-либо обучения, взяв, например $f_\theta(\mathbf{x}_{t}) \equiv 0$.

Поэтому, чтобы избежать вырожденных решений, нам необходимо выставить граничное условие (boundary condition), которое будет требовать, чтобы в самой левой точке траектории около 0, модель предсказывала картинку, которую получает на вход: $f_\theta(\mathbf{x}_{\epsilon}) = \mathbf{x}_{\epsilon}$.

Важное практическое замечание: Для обеих точек траектории мы применяем одну и ту же модель $f_\theta(\cdot)$. Но выход модели на шаге ${t-1}$ является "таргетом" для выхода модели на шаге $t$ и поэтому выполнение модели для шага $t-1$ выполняется в torch.no_grad режиме.

Как получаться две соседние точки на траектории ОДУ?

Берем случайную картинку $\mathbf{x}_0$ из датасета.

Точку $\mathbf{x}_t$ получаем с помощью прямого процесса зашумления: $\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$

Чтобы получить соседнюю точку $\mathbf{x}_{t-1}$, нам нужно сделать шаг по траектории ОДУ, используя, например, DDIM солвер.

В консистенси дистилляции, мы делаем шаг предобученной ДМ: $\mathbf{x}_{t-1} = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, t-1)$

In [1]:
from tqdm.auto import tqdm

import csv
import os
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, LCMScheduler, UNet2DConditionModel, DDIMScheduler

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict

%matplotlib inline
import matplotlib.pyplot as plt
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]
In [2]:
#---------------------
# Visualization utils
#---------------------

def visualize_images(images):
    assert len(images) == 4
    plt.figure(figsize=(12, 3))
    for i, image in enumerate(images):
        plt.subplot(1, 4, i+1)
        plt.imshow(image)
        plt.axis('off')

    plt.subplots_adjust(wspace=-0.01, hspace=-0.01)


#--------------
# Tensor utils
#--------------

def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

#---------------
# Dataset utils
#---------------

class COCODataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, subset_name="train2014_5k", transform=None, max_cnt=None):
        """
        Arguments:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.extensions = (
            ".jpg",
            ".jpeg",
            ".png",
            ".ppm",
            ".bmp",
            ".pgm",
            ".tif",
            ".tiff",
            ".webp",
        )
        sample_dir = os.path.join(root_dir, subset_name)

        # Collect sample paths
        self.samples = sorted(
            [
                os.path.join(sample_dir, fname)
                for fname in os.listdir(sample_dir)
                if fname[-4:] in self.extensions
            ],
            key=lambda x: x.split("/")[-1].split(".")[0],
        )
        self.samples = (
            self.samples if max_cnt is None else self.samples[:max_cnt]
        )  # restrict num samples

        # Collect captions
        self.captions = {}
        with open(
            os.path.join(root_dir, f"{subset_name}.csv"), newline="\n"
        ) as csvfile:
            spamreader = csv.reader(csvfile, delimiter=",")
            for i, row in enumerate(spamreader):
                if i == 0:
                    continue
                self.captions[row[1]] = row[2]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample_path = self.samples[idx]
        sample = Image.open(sample_path).convert("RGB")

        if self.transform:
            sample = self.transform(sample)

        return {
            "image": sample,
            "text": self.captions[os.path.basename(sample_path)],
            "idxs": idx, }

Модель учителя (SD1.5)¶

Задание №1¶

Давайте для начала загрузим модель StableDiffusion 1.5 и сгенерируем ей картинки за 50 шагов.

Важно: для экономии памяти, загружаем все компоненты модели в FP16. Не забываем положить модель на GPU.

In [22]:
pipe = StableDiffusionPipeline.from_pretrained(
    'sd-legacy/stable-diffusion-v1-5',
    torch_dtype=torch.float16,
    safety_checker=None,
).to('cuda')

# Проверяем, что все компоненты модели в FP16 и на cuda
assert pipe.unet.dtype == torch.float16 and pipe.unet.device.type == 'cuda'
assert pipe.vae.dtype == torch.float16 and pipe.vae.device.type == 'cuda'
assert pipe.text_encoder.dtype == torch.float16 and pipe.text_encoder.device.type == 'cuda'

# Заменяем дефолтный сэмплер на DDIM
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()

# Отдельно извлечем модель учителя, которую потом будем дистиллировать
teacher_unet = pipe.unet
model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]
Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]
model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]
text_encoder/config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]
scheduler/scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]
(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]
tokenizer/special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]
tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]
tokenizer/tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]
tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]
vae/config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]
diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]
Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .

Теперь сгенерируем картинки за 50 шагов. Вам нужно написать вызов pipe и передать в него промпт, число шагов генерации, генератор случайных чисел, гайденс скейл и указать, чтобы сгенерировалось 4 картинки на промпт.

In [3]:
prompt = "A sad puppy with large eyes"
guidance_scale = 7.5
generator = torch.Generator('cuda').manual_seed(1)
In [ ]:
images = pipe(
    prompt=prompt,
    num_inference_steps=50,
    num_images_per_prompt=4,
    generator=generator,
    guidance_scale=guidance_scale,
).images

visualize_images(images)
  0%|          | 0/50 [00:00<?, ?it/s]
No description has been provided for this image

Давайте посмотрим, что выдаст модель за 4 шага. Все то же самое, что и выше, просто поменяем число шагов.

In [ ]:
generator = torch.Generator('cuda').manual_seed(1)

images = pipe(
    prompt=prompt,
    num_inference_steps=4,
    num_images_per_prompt=4,
    generator=generator,
    guidance_scale=guidance_scale,
).images

visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image

На 4 шагах картинки получаются размазанными. Давайте постараемся починить их.

Создаем датасет¶

Чтобы ДЗ было легко выполнимым на colab, мы будем учить консистенси модели на небольшой обучающей выборке из 5000 пар текст-картинка из COCO датасета. Интересное свойство консистенси моделей - они могут сходиться до адекватного качества за несколько сотен шагов. Качество все еще будет не идеальным, но фазовый переход уже должен быть заметен.

Данные можно загрузить с помощью команд в ячейке ниже. В локальной текущей директории ./ должны появиться:

  • Папка train2014_5k с 5000 картинками
  • Файл train2014_5k.csv с 5000 промптами

Данные парсятся корректным образом в уже реализованном классе COCODataset.

In [4]:
# Колаб
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [ ]:
# Kaggle
!pip install PyDrive
In [ ]:
# Загрузка
!wget https://storage.yandexcloud.net/yandex-research/train2014_5k.tar.gz
!tar -C "data" -xzf train2014_5k.tar.gz

Замечание: для более быстрого дебаггинга можете взять, например, 2500 картинок и прогнать на всей выборке только в самом конце. 2500 картинок должно быть достаточно для понимания корректно ли реализованы функции. Совсем для первичного дебаггинга можно взять еще меньше картинок.

In [ ]:
from torchvision import transforms

transform = transforms.Compose(
    [
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        lambda x: 2 * x - 1,
    ]
)

colab = True
if colab:
    data_path = "./drive/MyDrive/cv_data"
else:
    data_path = "data"

dataset = COCODataset(data_path,
    subset_name="train2014_5k",
    transform=transform,
#     max_cnt=2500
)
assert len(dataset) == 5000 # 2500

batch_size = 8 # Рекоммендуемы размер батча на Colab

train_dataloader = torch.utils.data.DataLoader(
    dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True
)
In [5]:
@torch.no_grad()
def prepare_batch(batch, pipe):
    """
    Предобработка батча картинок и текстовых промптов.
    Маппим картинки в латентное пространство VAE.
    Извлекаем эмбеды промптов с помощью текстового энкодера.

    Params:

    Return:
        latents: torch.Tensor([B, 4, 64, 64], dtype=torch.float16)
        prompt_embeds: torch.Tensor([B, 77, D], dtype=torch.float16)
    """

    # Токенизируем промпты
    text_inputs = pipe.tokenizer(
        batch['text'],
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    # Извлекаем эмбеды промптов с помощью текстового энкодера
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids.cuda())[0]

    # Переводим картинки в латентное пространство VAE
    image = batch['image'].to("cuda", dtype=torch.float16)
    latents = pipe.vae.encode(image).latent_dist.sample()
    latents = latents * pipe.vae.config.scaling_factor
    return latents, prompt_embeds

Подготовка моделей и оптимизатора¶

Для начала создаем обучаемую модель: UNet инициализируемый весами SD1.5. Вам нужно воспользоваться классом UNet2DConditionModel и загрузить отдельно только UNet модель из SD1.5.

Отметим, что эта модель у нас будет храниться в полной точности FP32, потому что обучение параметров в FP16 может приводить к нестабильностям и низкому качеству.

In [6]:
unet = UNet2DConditionModel.from_pretrained(
    'sd-legacy/stable-diffusion-v1-5',
    subfolder='unet',
    torch_dtype=torch.float32,
).to('cuda').train()


assert unet.dtype == torch.float32
assert unet.training
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]
diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Для экономии памяти во время обучения будем учить не параметры самой модели, а добавим в нее обучаемые LoRA адаптеры с малым числом параметров.

LoRA представляет собой маленькую добавку к весам модели, где на одну матрицу весов $W \in \mathbb{R}^{m{\times}n} $ обучаются две низкоранговые матрицы $W_A \in \mathbb{R}^{k{\times}n}$ и $W_B \in \mathbb{R}^{k{\times}m}$, где $k$ - ранг матрицы сильно меньше $m$ и $n$.

Тем самым, новая обученная матрица весов может быть представлена как $\hat{W} = W + \Delta W = W + W^T_B W_A$.
Во время инференса $\Delta W$ можно вмержить в $W$ и получить итоговую модель. Также частая практика оставлять адаптеры как есть, чтобы была возможность для одной базовой модели учить несколько адаптеров под разные задачи и переключаться между ними по необходимости.

Если не мержить адаптеры, то вычисления для линейного слоя происходят как на картинке ниже.

In [7]:
# Указываем к каким слоям модели мы будет добавлять адаптеры.
lora_modules = [
    "to_q", "to_k", "to_v", "to_out.0", "proj_in", "proj_out",
    "ff.net.0.proj", "ff.net.2", "conv1", "conv2", "conv_shortcut",
    "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj"
]
lora_config = LoraConfig(
    r=64, # задает ранг у матриц A и B в LoRA.
    target_modules=lora_modules
)
In [ ]:
# Создаем обертку исходной UNet модели с LoRA адаптерами, используя библиотеку PEFT
cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")

# Включаем gradient checkpointing - важная техника для экономии памяти во время обучения
cm_unet.enable_gradient_checkpointing()

# Создаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
In [ ]:
# Задаем лосс функцию для CM обжектива. В базовом варианте разумно взять L2
# По умолчанию, она уже выдает усредненное значение по всем размерностям
mse_loss = torch.nn.functional.mse_loss

Задание №2 (0.5 балла, сдается в контесте)¶

Реализация шага DDIM¶

Шаг с помощью DDIM с $\mathbf{x}_t$ на $\mathbf{x}_s$ можно сделать следующим образом:

$ \mathbf{x}_s = DDIM(\epsilon_\theta, \mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $

Вам нужно реализовать эту формулу в уже готовом шаблоне ниже. Чтобы корректно выполнить задание, вам нужно задать $\alpha_t$ и $\sigma_t$ имея DDIMScheduler. **Обратите внимание на аттрибут *scheduler.alphas_cumprod***, который задает $\bar\alpha_{t} = \prod^t_{i=1} (1-\beta_i)$ в классической DDPM формулировке: Denoising Diffusion Probabilistic Models.

In [8]:
def ddim_solver_step(model_output, x_t, t, s, scheduler):
    """
    Шаг DDIM солвера для VP процесса зашумления и eps-prediction модели
    params:
        model_output: torch.Tensor[B, 4, 64, 64] - предсказание модели - шум eps
        x_t: torch.Tensor[B, 4, 64, 64] - сэмплы на шаге t
        t: torch.Tensor[B] - номер текущего шага
        s: torch.Tensor[B] - номер следующего шага
        scheduler: DDIMScheduler - расписание диффузионного процесса, чтобы получить alpha и sigma
    """
    alphas = torch.sqrt(scheduler.alphas_cumprod)
    sigmas = torch.sqrt(1 - scheduler.alphas_cumprod)

    sigmas_s = extract_into_tensor(sigmas, s, x_t.shape)
    alphas_s = extract_into_tensor(alphas, s, x_t.shape)

    sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
    alphas_t = extract_into_tensor(alphas, t, x_t.shape)

    # Выставляем крайние значения alpha и sigma, чтобы выполнялись граничные условия
    alphas_s[s == 0] = 1.0
    sigmas_s[s == 0] = 0.0

    alphas_t[t == 0] = 1.0
    sigmas_t[t == 0] = 0.0

    x_0 = (x_t - sigmas_t * model_output) / alphas_t # x0 оценка на шаге t
    x_s = alphas_s * x_0 + sigmas_s * model_output # Переход на шаг s
    return x_s

Реализация процесса зашумления (q sample)¶

Аналогично, нам нужен процесс зашумления $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$

$\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$

In [9]:
def q_sample(x, t, scheduler, noise=None):
    alphas = torch.sqrt(scheduler.alphas_cumprod)
    sigmas = torch.sqrt(1 - scheduler.alphas_cumprod)

    if noise is None:
        noise = torch.randn_like(x)

    sigmas_t = extract_into_tensor(sigmas, t, x.shape)
    alphas_t = extract_into_tensor(alphas, t, x.shape)

    alphas_t[t == 0] = 1.0
    sigmas_t[t == 0] = 0.0

    x_t = alphas_t * x + sigmas_t * noise
    return x_t

Consistency Training¶

Обучение консистенси моделей без учителя называется Consistency Training (CT). В таком случае CM можно рассматривать как отдельный вид генеративных моделей. Давайте начнем именно с этого подхода и обучим нашу первую консистенси модель на базе SD1.5.

Задание №3¶

Задание №3.1 (0.5 балла, сдается в контесте)¶

В консиcтенси дистилляции модель учителя используется для получения второй точки на траектории ODE. Можем ли мы попробовать оценить соседнюю точку аналитически?

Вам предлагается вывести это самим, используя формулу DDIM шага выше и вспомнив, как мы оцениваем скор функции в denoising score matching-e:

$\epsilon_\theta(x_t, t) = - \sigma_t s_\theta(x_t, t)$

$s_\theta(x_t, t) \approx \nabla_{x_t} \log q(x_t) = \mathop{\mathbb{E}}_{\mathbf{x}\sim p_{data}}\left [ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t | \mathbf{x}) \vert \mathbf{x}_t \right ] \approx \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x})$


< YOUR DERIVATION HERE >

Я много что попробовал сдать в контесте, но зашла только эта формула.

Имеем процесс зашумления для $x_t$:

$x_t = \alpha_t x_0 + \sigma_t \epsilon$

Выражаем отсюда $\epsilon$:

$\epsilon = \frac{x_t - \alpha_t x_0}{\sigma_t}$

Подставляем в процесс зашумления для $x_s$ и упрощаем:

$x_s = \alpha_s x_0 + \sigma_s \epsilon$

$x_s = \alpha_s x_0 + \sigma_s \frac{x_t - \alpha_t x_0}{\sigma_t}$

$x_s = \frac{\sigma_s}{\sigma_t} x_t + (\alpha_s - \frac{\sigma_s}{\sigma_t} \alpha_t) x_0$


Если возникнут трудность, можно обратиться к оригинальной статье.

Теперь реализуем то, что у вас получилось в функции ниже.

In [10]:
def get_xs_from_xt_naive(
    x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
    scheduler,
    noise=None,
    **kwargs
):
    """
    Получение точки x_s в CT режиме, т.е., аналитически.
    """

    alphas = torch.sqrt(scheduler.alphas_cumprod)
    sigmas = torch.sqrt(1 - scheduler.alphas_cumprod)

    sigmas_t = extract_into_tensor(sigmas, t, x_0.shape)
    alphas_t = extract_into_tensor(alphas, t, x_0.shape)
    sigmas_s = extract_into_tensor(sigmas, s, x_0.shape)
    alphas_s = extract_into_tensor(alphas, s, x_0.shape)

    alphas_t[t == 0] = 1.0
    sigmas_t[t == 0] = 0.0
    alphas_s[s == 0] = 1.0
    sigmas_s[s == 0] = 0.0

    if x_t is None:
        x_s = q_sample(x_0, t, scheduler, noise)

    else:
        x_s = (sigmas_s / sigmas_t) * x_t + (alphas_s - (sigmas_s / sigmas_t) * alphas_t) * x_0

    return x_s

Задание №3.2¶

Ниже предстален шаблон функции, которая считает лосс для консистенси моделей. Вам нужно правильно заполнить пропуски, чтобы получилась корректная функция.

In [ ]:
def cm_loss_template(
    latents, prompt_embeds, # батч латентов и текстовых эмбедов
    unet, scheduler,

    # Функции, которые будем постепенно менять из задания к заданию
    loss_fn: callable,
    get_boundary_timesteps: callable,
    get_xs_from_xt: callable,

    num_timesteps=1000,
    step_size=20, # Указываем с каким интервалом берем шаги s и t.
):
    # Сэмплируем случайные шаги t для каждого элемента батча t ~ U[step_size-1, 999]
    assert num_timesteps == 1000
    num_intervals = num_timesteps // step_size

    index = torch.randint(1, num_intervals, (len(latents),), device=latents.device).long() # [1, num_intervals]
    t = step_size * index - 1
    s = torch.clamp(t - step_size, min=0)
    boundary_timesteps = get_boundary_timesteps(
        s, num_timesteps=num_timesteps
    )

    # Сэмплируем x_t
    noise = torch.randn_like(latents)
    x_t = q_sample(latents, t, scheduler, noise)

    # with <YOUR CODE HERE>: # для реализации mixed-precision обучения в задании №4
    with torch.amp.autocast("cuda", torch.float16):  # Mixed precision
        noise_pred = unet(x_t.float(), t,
            encoder_hidden_states=prompt_embeds.float(),
        ).sample

    # Получаем оценку в граничной точке для x_t
    boundary_pred = ddim_solver_step(
        model_output=noise_pred, x_t=x_t, t=t, s=boundary_timesteps, scheduler=scheduler
    )

    # Получаем сэмпл x_s из x_t
    x_s = get_xs_from_xt(
        latents, x_t, t, s,
        scheduler,
        prompt_embeds=prompt_embeds,
        noise=noise,
    )

    # Предсказание "таргет моделью"
    with torch.no_grad(), torch.amp.autocast("cuda", torch.float16):
        target_noise_pred = unet(x_s, s, encoder_hidden_states=prompt_embeds).sample

    # Получаем оценку в граничной точке для x_s
    boundary_target = ddim_solver_step(
        model_output=target_noise_pred, x_t=x_s, t=s, s=boundary_timesteps, scheduler=scheduler
    )

    loss = loss_fn(boundary_pred, boundary_target)
    return loss
In [ ]:
import functools
In [ ]:
def get_zero_boundary_timesteps(t, **kwargs):
    """
    Определяем шаги где будут срабатывать граничные условия.
    Для классических СM это t=0.
    """
    return torch.zeros_like(t)


ct_loss = functools.partial(
    cm_loss_template,

    loss_fn=mse_loss,
    get_boundary_timesteps=get_zero_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_naive
)
assert cm_unet.active_adapter == 'ct'

Задание №4¶

Эффективное обучение¶

Данное задание рассчитано на успешное выполнение на colab с бесплатной Tesla T4 c 15GB VRAM. Однако учить даже относительно небольшие T2I модели масштаба SD1.5 уже на коллабе в лоб проблематично.

Для этого нам нужно применить ряд инженерных техник, чтобы уместиться в данный бюджет и учиться за разумное время.

Список техник

  1. Включить gradient checkpointing для обучемой модели
  2. Добавить LoRA (Low Rank Adapters) адаптеры, чтобы учить не все веса, а только 10% добавочных весов
  3. Использовать gradient accumulation, чтобы делать итерацию обучения по бОльшему батчу, чем влезает по памяти
  4. Добавить mixed precision FP16/FP32 обучение модели для скорости. Обычно еще и память экономится, но в случае LoRA обучения + gradient checkpointing на память сильно влиять не должно, но зато станет быстрее.
  5. Мульти-GPU обучение - распределение вычислений по нескольким GPU.

1-2) Мы уже применили за вас выше

3-4) Предстоит реализовать вам самим в соотвествующей секции ниже

5 ) Недоступно, так как работаем на одной карточке

Обучающий цикл¶

Вам дан код обучения модель в полной точности (FP32) c батчом 8. К сожалению, на Tesla T4 мы не влезем по памяти. Поэтому в ячейке ниже вам нужно модифицировать цикл, чтобы он работал в mixed precision FP16 и добавить gradient accumulation.

Про реализацию mixed-precision в pytorch можно перейти по ссылке: Mixed-precision обучение

Обратите внимание: вам еще нужно добавить одну строчку кода в cm_loss_template в соответствующем плейсхолдере.

Замечание: В начале обучения значения лосса должны быть в окрестности 0.0007-0.001. Ничего страшного, что лосс не падает, для CM это нормально. В конце обучения лосс может доходить до 0.005-0.01

In [ ]:
from torch.cuda.amp import autocast, GradScaler

def train_loop(model, pipe, train_dataloader, optimizer, loss_fn, num_grad_accum=1):
    torch.cuda.empty_cache()

    # Создаем скейлер для mixed precision
    scaler = GradScaler()

    # Итерация по батчам
    for i, batch in enumerate(tqdm(train_dataloader)):

        latents, prompt_embeds = prepare_batch(batch, pipe)

        # Forward + backward с учетом gradient accumulation
        with autocast(dtype=torch.float16):  # Включаем mixed precision
            loss = loss_fn(latents, prompt_embeds, model, pipe.scheduler) / num_grad_accum

        # Backward с GradScaler
        scaler.scale(loss).backward()

        # Обновляем параметры каждые num_grad_accum шагов
        if (i + 1) % num_grad_accum == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)  # Сбрасываем градиенты

        # Логирование
        if i % num_grad_accum == 0 or (i + 1) % num_grad_accum == 0:
            print(f"Step {i + 1}, Loss: {loss.detach().item() * num_grad_accum}")
In [ ]:
num_grad_accum = 2 # обновляем параметры каждые 2 шага

train_loop(cm_unet, pipe, train_dataloader, optimizer, ct_loss, num_grad_accum)
<ipython-input-17-04d1c509c8fc>:7: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler()
  0%|          | 0/625 [00:00<?, ?it/s]
<ipython-input-17-04d1c509c8fc>:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(dtype=torch.float16):  # Включаем mixed precision
Step 1, Loss: 0.0009964737109839916
Step 2, Loss: 0.0007123646792024374
Step 3, Loss: 0.0025129171553999186
Step 4, Loss: 0.002672075293958187
Step 5, Loss: 0.0012879869900643826
Step 6, Loss: 0.0025040223263204098
Step 7, Loss: 0.0009721919195726514
Step 8, Loss: 0.0007318910211324692
Step 9, Loss: 0.0028819323051720858
Step 10, Loss: 0.0010622225236147642
Step 11, Loss: 0.0020778956823050976
Step 12, Loss: 0.0010942246299237013
Step 13, Loss: 0.0012017062399536371
Step 14, Loss: 0.0007995384512469172
Step 15, Loss: 0.0006579371402040124
Step 16, Loss: 0.0011498258681967854
Step 17, Loss: 0.0008381439838558435
Step 18, Loss: 0.0013155571650713682
Step 19, Loss: 0.001420671702362597
Step 20, Loss: 0.0007696166867390275
Step 21, Loss: 0.0008418669458478689
Step 22, Loss: 0.0015048839850351214
Step 23, Loss: 0.0008957093814387918
Step 24, Loss: 0.0025862189941108227
Step 25, Loss: 0.001034852466545999
Step 26, Loss: 0.00190110900439322
Step 27, Loss: 0.0006745924474671483
Step 28, Loss: 0.0020768148824572563
Step 29, Loss: 0.0013370923697948456
Step 30, Loss: 0.0008536883979104459
Step 31, Loss: 0.001032006461173296
Step 32, Loss: 0.0017720642499625683
Step 33, Loss: 0.0009183366782963276
Step 34, Loss: 0.0014106683665886521
Step 35, Loss: 0.0012036757543683052
Step 36, Loss: 0.001591663807630539
Step 37, Loss: 0.0030212663114070892
Step 38, Loss: 0.0023587362375110388
Step 39, Loss: 0.0014438326470553875
Step 40, Loss: 0.0007154847262427211
Step 41, Loss: 0.001517471857368946
Step 42, Loss: 0.0009275604970753193
Step 43, Loss: 0.0016475626034662127
Step 44, Loss: 0.0015158411115407944
Step 45, Loss: 0.0010888006072491407
Step 46, Loss: 0.0008507425663992763
Step 47, Loss: 0.000835998565889895
Step 48, Loss: 0.002443537348881364
Step 49, Loss: 0.0011052710469812155
Step 50, Loss: 0.001367520191706717
Step 51, Loss: 0.0014835491310805082
Step 52, Loss: 0.0021522396709769964
Step 53, Loss: 0.001071461010724306
Step 54, Loss: 0.0008708474924787879
Step 55, Loss: 0.0016729355556890368
Step 56, Loss: 0.0012967146467417479
Step 57, Loss: 0.0024283856619149446
Step 58, Loss: 0.004348785616457462
Step 59, Loss: 0.0018030828796327114
Step 60, Loss: 0.005939250811934471
Step 61, Loss: 0.0015603761421516538
Step 62, Loss: 0.0015945896739140153
Step 63, Loss: 0.001367853139527142
Step 64, Loss: 0.0010217612143605947
Step 65, Loss: 0.0018260078504681587
Step 66, Loss: 0.0016675188671797514
Step 67, Loss: 0.002414319198578596
Step 68, Loss: 0.0022158417850732803
Step 69, Loss: 0.0013161207316443324
Step 70, Loss: 0.0009690714068710804
Step 71, Loss: 0.0037761759012937546
Step 72, Loss: 0.0013274485245347023
Step 73, Loss: 0.0017683268524706364
Step 74, Loss: 0.0022360978182405233
Step 75, Loss: 0.0018206157255917788
Step 76, Loss: 0.0019233342027291656
Step 77, Loss: 0.0019138399511575699
Step 78, Loss: 0.0018020558636635542
Step 79, Loss: 0.0013781136367470026
Step 80, Loss: 0.003604952711611986
Step 81, Loss: 0.002405259758234024
Step 82, Loss: 0.0013625816209241748
Step 83, Loss: 0.00869310274720192
Step 84, Loss: 0.00275280955247581
Step 85, Loss: 0.004032185301184654
Step 86, Loss: 0.00870254822075367
Step 87, Loss: 0.002071268390864134
Step 88, Loss: 0.0018498199060559273
Step 89, Loss: 0.00991445779800415
Step 90, Loss: 0.0030165230855345726
Step 91, Loss: 0.003333171596750617
Step 92, Loss: 0.003185434266924858
Step 93, Loss: 0.0021970977541059256
Step 94, Loss: 0.003695380873978138
Step 95, Loss: 0.013177640736103058
Step 96, Loss: 0.005383758805692196
Step 97, Loss: 0.005162822548300028
Step 98, Loss: 0.005881481803953648
Step 99, Loss: 0.005417659878730774
Step 100, Loss: 0.007452243007719517
Step 101, Loss: 0.008300170302391052
Step 102, Loss: 0.006970586255192757
Step 103, Loss: 0.00356032932177186
Step 104, Loss: 0.004906745627522469
Step 105, Loss: 0.0038089922163635492
Step 106, Loss: 0.0015337818767875433
Step 107, Loss: 0.00300682638771832
Step 108, Loss: 0.008650153875350952
Step 109, Loss: 0.004231759812682867
Step 110, Loss: 0.0029275708366185427
Step 111, Loss: 0.003925336059182882
Step 112, Loss: 0.004475311376154423
Step 113, Loss: 0.0034971064887940884
Step 114, Loss: 0.004230810329318047
Step 115, Loss: 0.0024554363917559385
Step 116, Loss: 0.0021734843030571938
Step 117, Loss: 0.008273965679109097
Step 118, Loss: 0.002536300104111433
Step 119, Loss: 0.003144968766719103
Step 120, Loss: 0.0033394754864275455
Step 121, Loss: 0.0019214244093745947
Step 122, Loss: 0.001339460490271449
Step 123, Loss: 0.003105826210230589
Step 124, Loss: 0.0019138716161251068
Step 125, Loss: 0.0009612067369744182
Step 126, Loss: 0.0028296932578086853
Step 127, Loss: 0.0021417844109237194
Step 128, Loss: 0.004075353033840656
Step 129, Loss: 0.004901350475847721
Step 130, Loss: 0.0011327839456498623
Step 131, Loss: 0.004618014208972454
Step 132, Loss: 0.002003858797252178
Step 133, Loss: 0.003282646182924509
Step 134, Loss: 0.0014006216078996658
Step 135, Loss: 0.0014191106893122196
Step 136, Loss: 0.0037141223438084126
Step 137, Loss: 0.002633387455716729
Step 138, Loss: 0.002554066013544798
Step 139, Loss: 0.0018516054842621088
Step 140, Loss: 0.0018092331010848284
Step 141, Loss: 0.0017135003581643105
Step 142, Loss: 0.0036260229535400867
Step 143, Loss: 0.0030195917934179306
Step 144, Loss: 0.0009842927101999521
Step 145, Loss: 0.0015842400025576353
Step 146, Loss: 0.0017116828821599483
Step 147, Loss: 0.0015447353944182396
Step 148, Loss: 0.001342438394203782
Step 149, Loss: 0.002777844201773405
Step 150, Loss: 0.001807103049941361
Step 151, Loss: 0.0025249419268220663
Step 152, Loss: 0.0017091265181079507
Step 153, Loss: 0.0024985510390251875
Step 154, Loss: 0.001993537414819002
Step 155, Loss: 0.0019271315541118383
Step 156, Loss: 0.0018232737202197313
Step 157, Loss: 0.002669207751750946
Step 158, Loss: 0.0013759227003902197
Step 159, Loss: 0.0024884392041713
Step 160, Loss: 0.003064761171117425
Step 161, Loss: 0.0016928670229390264
Step 162, Loss: 0.0018594666616991162
Step 163, Loss: 0.0022970836143940687
Step 164, Loss: 0.0031568421982228756
Step 165, Loss: 0.007324790116399527
Step 166, Loss: 0.001695355400443077
Step 167, Loss: 0.0017826779512688518
Step 168, Loss: 0.0011040661484003067
Step 169, Loss: 0.0010679103434085846
Step 170, Loss: 0.0011576625984162092
Step 171, Loss: 0.003895317669957876
Step 172, Loss: 0.001107738702557981
Step 173, Loss: 0.002163609955459833
Step 174, Loss: 0.0024494314566254616
Step 175, Loss: 0.001972742145881057
Step 176, Loss: 0.001785499625839293
Step 177, Loss: 0.002004774287343025
Step 178, Loss: 0.0009915400296449661
Step 179, Loss: 0.0015638095792382956
Step 180, Loss: 0.0013448246754705906
Step 181, Loss: 0.0031648697331547737
Step 182, Loss: 0.001704327529296279
Step 183, Loss: 0.002551598474383354
Step 184, Loss: 0.002767860423773527
Step 185, Loss: 0.002158526098355651
Step 186, Loss: 0.001088866381905973
Step 187, Loss: 0.0024332161992788315
Step 188, Loss: 0.001555718365125358
Step 189, Loss: 0.0014911929611116648
Step 190, Loss: 0.0012259191134944558
Step 191, Loss: 0.0019990727305412292
Step 192, Loss: 0.004777118563652039
Step 193, Loss: 0.0017879370134323835
Step 194, Loss: 0.0022225005086511374
Step 195, Loss: 0.0018355753272771835
Step 196, Loss: 0.0021953212562948465
Step 197, Loss: 0.0019467687234282494
Step 198, Loss: 0.0033462434075772762
Step 199, Loss: 0.0018740312661975622
Step 200, Loss: 0.0010305580217391253
Step 201, Loss: 0.0024460656568408012
Step 202, Loss: 0.0010713809169828892
Step 203, Loss: 0.004301274195313454
Step 204, Loss: 0.002050131093710661
Step 205, Loss: 0.0015187339158728719
Step 206, Loss: 0.0028489278629422188
Step 207, Loss: 0.0027420278638601303
Step 208, Loss: 0.002065301174297929
Step 209, Loss: 0.0020583991426974535
Step 210, Loss: 0.0030916407704353333
Step 211, Loss: 0.003561109770089388
Step 212, Loss: 0.0018161950865760446
Step 213, Loss: 0.001744080800563097
Step 214, Loss: 0.003127105999737978
Step 215, Loss: 0.0030535277910530567
Step 216, Loss: 0.0025577410124242306
Step 217, Loss: 0.002055512275546789
Step 218, Loss: 0.0013864929787814617
Step 219, Loss: 0.0029660311993211508
Step 220, Loss: 0.0013375321868807077
Step 221, Loss: 0.002898484468460083
Step 222, Loss: 0.0024262629449367523
Step 223, Loss: 0.003130171447992325
Step 224, Loss: 0.0014638151042163372
Step 225, Loss: 0.00363878789357841
Step 226, Loss: 0.003392276354134083
Step 227, Loss: 0.002066464629024267
Step 228, Loss: 0.003555938834324479
Step 229, Loss: 0.0033692405559122562
Step 230, Loss: 0.0010493635199964046
Step 231, Loss: 0.0010821252362802625
Step 232, Loss: 0.0017637109849601984
Step 233, Loss: 0.0012440495193004608
Step 234, Loss: 0.0016341344453394413
Step 235, Loss: 0.0015762588009238243
Step 236, Loss: 0.0010235632071271539
Step 237, Loss: 0.001646367833018303
Step 238, Loss: 0.002081636106595397
Step 239, Loss: 0.0012786537408828735
Step 240, Loss: 0.002571119461208582
Step 241, Loss: 0.0013509748969227076
Step 242, Loss: 0.003147643292322755
Step 243, Loss: 0.0014911983162164688
Step 244, Loss: 0.0029995788354426622
Step 245, Loss: 0.0022395257838070393
Step 246, Loss: 0.0013257043901830912
Step 247, Loss: 0.0016677497187629342
Step 248, Loss: 0.001462545245885849
Step 249, Loss: 0.003457159036770463
Step 250, Loss: 0.0013242086861282587
Step 251, Loss: 0.0021175735164433718
Step 252, Loss: 0.0023571918718516827
Step 253, Loss: 0.0030575694981962442
Step 254, Loss: 0.0024516936391592026
Step 255, Loss: 0.002214903710409999
Step 256, Loss: 0.0020048075821250677
Step 257, Loss: 0.0030402978882193565
Step 258, Loss: 0.0031706318259239197
Step 259, Loss: 0.0019083842635154724
Step 260, Loss: 0.0021437020041048527
Step 261, Loss: 0.001716109924018383
Step 262, Loss: 0.0012771545443683863
Step 263, Loss: 0.0017771539278328419
Step 264, Loss: 0.0015819550026208162
Step 265, Loss: 0.002498017158359289
Step 266, Loss: 0.0016480679623782635
Step 267, Loss: 0.0022397267166525126
Step 268, Loss: 0.002551595214754343
Step 269, Loss: 0.001356970053166151
Step 270, Loss: 0.002087733708322048
Step 271, Loss: 0.0026639692950993776
Step 272, Loss: 0.0052254656329751015
Step 273, Loss: 0.0013680905103683472
Step 274, Loss: 0.0036784247495234013
Step 275, Loss: 0.000992199988104403
Step 276, Loss: 0.0018863962031900883
Step 277, Loss: 0.0029394528828561306
Step 278, Loss: 0.001596520422026515
Step 279, Loss: 0.004238415509462357
Step 280, Loss: 0.0030783922411501408
Step 281, Loss: 0.0021919207647442818
Step 282, Loss: 0.002420907374471426
Step 283, Loss: 0.0017966690938919783
Step 284, Loss: 0.001461828825995326
Step 285, Loss: 0.0036598462611436844
Step 286, Loss: 0.0012443051673471928
Step 287, Loss: 0.00148701760917902
Step 288, Loss: 0.0017320663901045918
Step 289, Loss: 0.0014320281334221363
Step 290, Loss: 0.008277904242277145
Step 291, Loss: 0.0017746041994541883
Step 292, Loss: 0.001563466852530837
Step 293, Loss: 0.002196545246988535
Step 294, Loss: 0.003776966128498316
Step 295, Loss: 0.001591964508406818
Step 296, Loss: 0.001984924543648958
Step 297, Loss: 0.002549830824136734
Step 298, Loss: 0.0013229760807007551
Step 299, Loss: 0.0014838275965303183
Step 300, Loss: 0.0012145370710641146
Step 301, Loss: 0.002189289778470993
Step 302, Loss: 0.0022014169953763485
Step 303, Loss: 0.001216724282130599
Step 304, Loss: 0.0013273911317810416
Step 305, Loss: 0.0013471723068505526
Step 306, Loss: 0.002097218995913863
Step 307, Loss: 0.0031583341769874096
Step 308, Loss: 0.0027093582320958376
Step 309, Loss: 0.004970692563802004
Step 310, Loss: 0.0018247850239276886
Step 311, Loss: 0.00376988691277802
Step 312, Loss: 0.0016975709004327655
Step 313, Loss: 0.0009460895671509206
Step 314, Loss: 0.003457896178588271
Step 315, Loss: 0.0015740500530228019
Step 316, Loss: 0.0018885633908212185
Step 317, Loss: 0.002235714579001069
Step 318, Loss: 0.0029280800372362137
Step 319, Loss: 0.0021540713496506214
Step 320, Loss: 0.0022602220997214317
Step 321, Loss: 0.0018740276573225856
Step 322, Loss: 0.0022994857281446457
Step 323, Loss: 0.0013489238917827606
Step 324, Loss: 0.00139906897675246
Step 325, Loss: 0.0030182464979588985
Step 326, Loss: 0.0015988049563020468
Step 327, Loss: 0.0018778664525598288
Step 328, Loss: 0.001625601202249527
Step 329, Loss: 0.0018519361037760973
Step 330, Loss: 0.0017070379108190536
Step 331, Loss: 0.0021748319268226624
Step 332, Loss: 0.0017367968102917075
Step 333, Loss: 0.0014846296980977058
Step 334, Loss: 0.002475373214110732
Step 335, Loss: 0.002954889554530382
Step 336, Loss: 0.0046874405816197395
Step 337, Loss: 0.00242308946326375
Step 338, Loss: 0.001720296568237245
Step 339, Loss: 0.0010912759462371469
Step 340, Loss: 0.0018609871622174978
Step 341, Loss: 0.00340099073946476
Step 342, Loss: 0.004037333186715841
Step 343, Loss: 0.0028379643335938454
Step 344, Loss: 0.0013051190180703998
Step 345, Loss: 0.002631761133670807
Step 346, Loss: 0.0025198571383953094
Step 347, Loss: 0.0032972507178783417
Step 348, Loss: 0.003289341926574707
Step 349, Loss: 0.0017978892428800464
Step 350, Loss: 0.0033561696764081717
Step 351, Loss: 0.0021787723526358604
Step 352, Loss: 0.003375221509486437
Step 353, Loss: 0.001669022487476468
Step 354, Loss: 0.002778403926640749
Step 355, Loss: 0.0023496164940297604
Step 356, Loss: 0.0018149161478504539
Step 357, Loss: 0.0014565347228199244
Step 358, Loss: 0.0016320422291755676
Step 359, Loss: 0.002429002895951271
Step 360, Loss: 0.0029447644483298063
Step 361, Loss: 0.0025833663530647755
Step 362, Loss: 0.0026737430598586798
Step 363, Loss: 0.0021661401260644197
Step 364, Loss: 0.0017511245096102357
Step 365, Loss: 0.002903490327298641
Step 366, Loss: 0.005947783123701811
Step 367, Loss: 0.004735832568258047
Step 368, Loss: 0.0019470665138214827
Step 369, Loss: 0.004013554193079472
Step 370, Loss: 0.002077719196677208
Step 371, Loss: 0.002286328934133053
Step 372, Loss: 0.003162218490615487
Step 373, Loss: 0.003371443133801222
Step 374, Loss: 0.0022568991407752037
Step 375, Loss: 0.0013082942459732294
Step 376, Loss: 0.0021568434312939644
Step 377, Loss: 0.002213746774941683
Step 378, Loss: 0.0030593108385801315
Step 379, Loss: 0.0019463550997897983
Step 380, Loss: 0.003151160664856434
Step 381, Loss: 0.0022312854416668415
Step 382, Loss: 0.0017455383203923702
Step 383, Loss: 0.0027218933682888746
Step 384, Loss: 0.0014385171234607697
Step 385, Loss: 0.004356844816356897
Step 386, Loss: 0.0030530733056366444
Step 387, Loss: 0.0015455837128683925
Step 388, Loss: 0.0036600164603441954
Step 389, Loss: 0.003361268900334835
Step 390, Loss: 0.0047602481208741665
Step 391, Loss: 0.0034189848229289055
Step 392, Loss: 0.0022142487578094006
Step 393, Loss: 0.0018446831963956356
Step 394, Loss: 0.0021419369149953127
Step 395, Loss: 0.0016223136335611343
Step 396, Loss: 0.002930941991508007
Step 397, Loss: 0.002225317992269993
Step 398, Loss: 0.0027877832762897015
Step 399, Loss: 0.002290947362780571
Step 400, Loss: 0.0028579893987625837
Step 401, Loss: 0.0017973837675526738
Step 402, Loss: 0.0019481240306049585
Step 403, Loss: 0.0030283827800303698
Step 404, Loss: 0.0019655197393149137
Step 405, Loss: 0.0027977675199508667
Step 406, Loss: 0.0035928445868194103
Step 407, Loss: 0.001847823616117239
Step 408, Loss: 0.001033484935760498
Step 409, Loss: 0.002944400068372488
Step 410, Loss: 0.0029716547578573227
Step 411, Loss: 0.0015361449914053082
Step 412, Loss: 0.0024315500631928444
Step 413, Loss: 0.0021710386499762535
Step 414, Loss: 0.0012272449675947428
Step 415, Loss: 0.00263773906044662
Step 416, Loss: 0.0038949991576373577
Step 417, Loss: 0.0014056290965527296
Step 418, Loss: 0.001367239747196436
Step 419, Loss: 0.003612218890339136
Step 420, Loss: 0.0015791754703968763
Step 421, Loss: 0.0046881032176315784
Step 422, Loss: 0.002825289499014616
Step 423, Loss: 0.0012546565849334002
Step 424, Loss: 0.0020847772248089314
Step 425, Loss: 0.002753379987552762
Step 426, Loss: 0.00311181228607893
Step 427, Loss: 0.0018361861584708095
Step 428, Loss: 0.0020177848637104034
Step 429, Loss: 0.0017246806528419256
Step 430, Loss: 0.0022823826875537634
Step 431, Loss: 0.0038195070810616016
Step 432, Loss: 0.002087909961119294
Step 433, Loss: 0.001844338490627706
Step 434, Loss: 0.0013272862415760756
Step 435, Loss: 0.0024121454916894436
Step 436, Loss: 0.004444989841431379
Step 437, Loss: 0.001688591786660254
Step 438, Loss: 0.0016547583509236574
Step 439, Loss: 0.0028140549547970295
Step 440, Loss: 0.0022476916201412678
Step 441, Loss: 0.0026859110221266747
Step 442, Loss: 0.0029036798514425755
Step 443, Loss: 0.002507369965314865
Step 444, Loss: 0.002403222257271409
Step 445, Loss: 0.0014040693640708923
Step 446, Loss: 0.002955233445391059
Step 447, Loss: 0.0043233102187514305
Step 448, Loss: 0.004444340243935585
Step 449, Loss: 0.002760730916634202
Step 450, Loss: 0.0037729348987340927
Step 451, Loss: 0.0037752510979771614
Step 452, Loss: 0.002742704004049301
Step 453, Loss: 0.001664300449192524
Step 454, Loss: 0.003079385496675968
Step 455, Loss: 0.001461958047002554
Step 456, Loss: 0.002757724840193987
Step 457, Loss: 0.0021401511039584875
Step 458, Loss: 0.0033850609324872494
Step 459, Loss: 0.003520863363519311
Step 460, Loss: 0.003346397541463375
Step 461, Loss: 0.0019605071283876896
Step 462, Loss: 0.003859851509332657
Step 463, Loss: 0.003244594670832157
Step 464, Loss: 0.004167188890278339
Step 465, Loss: 0.0025255850050598383
Step 466, Loss: 0.003343079937621951
Step 467, Loss: 0.0036160782910883427
Step 468, Loss: 0.004407159052789211
Step 469, Loss: 0.0034948645625263453
Step 470, Loss: 0.0021185046061873436
Step 471, Loss: 0.002002623863518238
Step 472, Loss: 0.0018690098077058792
Step 473, Loss: 0.004508321639150381
Step 474, Loss: 0.0017977735260501504
Step 475, Loss: 0.0017795597668737173
Step 476, Loss: 0.0014956528320908546
Step 477, Loss: 0.001654409570619464
Step 478, Loss: 0.002867832314223051
Step 479, Loss: 0.0021172226406633854
Step 480, Loss: 0.0013408382656052709
Step 481, Loss: 0.00435650022700429
Step 482, Loss: 0.0030726974364370108
Step 483, Loss: 0.0013296615798026323
Step 484, Loss: 0.001793580362573266
Step 485, Loss: 0.0016964077949523926
Step 486, Loss: 0.001986682415008545
Step 487, Loss: 0.0020661265589296818
Step 488, Loss: 0.0026734794955700636
Step 489, Loss: 0.003145786002278328
Step 490, Loss: 0.0022016228176653385
Step 491, Loss: 0.002660672180354595
Step 492, Loss: 0.001197907142341137
Step 493, Loss: 0.003564678831025958
Step 494, Loss: 0.002485886448994279
Step 495, Loss: 0.0019238153472542763
Step 496, Loss: 0.0022778897546231747
Step 497, Loss: 0.002266446128487587
Step 498, Loss: 0.004147983156144619
Step 499, Loss: 0.004518461879342794
Step 500, Loss: 0.002260998822748661
Step 501, Loss: 0.0029638917185366154
Step 502, Loss: 0.0026786173693835735
Step 503, Loss: 0.0016313818050548434
Step 504, Loss: 0.0017759317997843027
Step 505, Loss: 0.0021710728760808706
Step 506, Loss: 0.0029801903292536736
Step 507, Loss: 0.0018787817098200321
Step 508, Loss: 0.004778842441737652
Step 509, Loss: 0.0021530394442379475
Step 510, Loss: 0.004105462692677975
Step 511, Loss: 0.003464809153228998
Step 512, Loss: 0.0026504206471145153
Step 513, Loss: 0.0022748950868844986
Step 514, Loss: 0.001675811829045415
Step 515, Loss: 0.0021095022093504667
Step 516, Loss: 0.0020678124856203794
Step 517, Loss: 0.0029012225568294525
Step 518, Loss: 0.004787777550518513
Step 519, Loss: 0.0035675661638379097
Step 520, Loss: 0.0033075539395213127
Step 521, Loss: 0.002297051018103957
Step 522, Loss: 0.0034762569703161716
Step 523, Loss: 0.0032242024317383766
Step 524, Loss: 0.002810206264257431
Step 525, Loss: 0.0016817448195070028
Step 526, Loss: 0.0018216629978269339
Step 527, Loss: 0.002537459833547473
Step 528, Loss: 0.003333230037242174
Step 529, Loss: 0.001697663450613618
Step 530, Loss: 0.0038388343527913094
Step 531, Loss: 0.002034861594438553
Step 532, Loss: 0.002954866038635373
Step 533, Loss: 0.003910453990101814
Step 534, Loss: 0.0020375121384859085
Step 535, Loss: 0.002452058019116521
Step 536, Loss: 0.0020335500594228506
Step 537, Loss: 0.0019645700231194496
Step 538, Loss: 0.0033477586694061756
Step 539, Loss: 0.004655073396861553
Step 540, Loss: 0.001522570033557713
Step 541, Loss: 0.003043032716959715
Step 542, Loss: 0.0023580617271363735
Step 543, Loss: 0.0036311703734099865
Step 544, Loss: 0.002830999670550227
Step 545, Loss: 0.0026946349535137415
Step 546, Loss: 0.0027369363233447075
Step 547, Loss: 0.0014522189740091562
Step 548, Loss: 0.0018265830585733056
Step 549, Loss: 0.0013534543104469776
Step 550, Loss: 0.001750377588905394
Step 551, Loss: 0.0018752054311335087
Step 552, Loss: 0.003588886931538582
Step 553, Loss: 0.0023917958606034517
Step 554, Loss: 0.002180017763748765
Step 555, Loss: 0.0013729555066674948
Step 556, Loss: 0.002268531359732151
Step 557, Loss: 0.001678610104136169
Step 558, Loss: 0.0031022634357213974
Step 559, Loss: 0.0020591646898537874
Step 560, Loss: 0.003965404815971851
Step 561, Loss: 0.0015510644298046827
Step 562, Loss: 0.0015438924310728908
Step 563, Loss: 0.0026267962530255318
Step 564, Loss: 0.003942787181586027
Step 565, Loss: 0.0023364638909697533
Step 566, Loss: 0.0017311549745500088
Step 567, Loss: 0.0023734201677143574
Step 568, Loss: 0.0023933923803269863
Step 569, Loss: 0.0020020711235702038
Step 570, Loss: 0.001555254915729165
Step 571, Loss: 0.0016916969325393438
Step 572, Loss: 0.001597439288161695
Step 573, Loss: 0.004049327224493027
Step 574, Loss: 0.004352931398898363
Step 575, Loss: 0.002702228259295225
Step 576, Loss: 0.004212164785712957
Step 577, Loss: 0.002407374791800976
Step 578, Loss: 0.0033760140649974346
Step 579, Loss: 0.003392706857994199
Step 580, Loss: 0.0023206849582493305
Step 581, Loss: 0.0013627472799271345
Step 582, Loss: 0.002573030535131693
Step 583, Loss: 0.0023301132023334503
Step 584, Loss: 0.00240900507196784
Step 585, Loss: 0.004537998698651791
Step 586, Loss: 0.0022711679339408875
Step 587, Loss: 0.004056943580508232
Step 588, Loss: 0.002882580505684018
Step 589, Loss: 0.002988439751788974
Step 590, Loss: 0.005434936378151178
Step 591, Loss: 0.0025192226748913527
Step 592, Loss: 0.0028609472792595625
Step 593, Loss: 0.0015962341567501426
Step 594, Loss: 0.004357056692242622
Step 595, Loss: 0.0018941131420433521
Step 596, Loss: 0.0015789249446243048
Step 597, Loss: 0.0023252214305102825
Step 598, Loss: 0.0018447035690769553
Step 599, Loss: 0.002543957205489278
Step 600, Loss: 0.002252105623483658
Step 601, Loss: 0.002090814057737589
Step 602, Loss: 0.002668071072548628
Step 603, Loss: 0.001780823222361505
Step 604, Loss: 0.00167083612177521
Step 605, Loss: 0.0043045347556471825
Step 606, Loss: 0.0017265079077333212
Step 607, Loss: 0.001871362328529358
Step 608, Loss: 0.003699228400364518
Step 609, Loss: 0.00566853117197752
Step 610, Loss: 0.002942024264484644
Step 611, Loss: 0.0027231022249907255
Step 612, Loss: 0.007950660772621632
Step 613, Loss: 0.002311816206201911
Step 614, Loss: 0.002136490074917674
Step 615, Loss: 0.0017656716518104076
Step 616, Loss: 0.001274152658879757
Step 617, Loss: 0.0034350096248090267
Step 618, Loss: 0.001877711503766477
Step 619, Loss: 0.001676612184382975
Step 620, Loss: 0.002426299499347806
Step 621, Loss: 0.002206122037023306
Step 622, Loss: 0.002395118586719036
Step 623, Loss: 0.002627358539029956
Step 624, Loss: 0.0031444155611097813
Step 625, Loss: 0.0015789918834343553
In [ ]:
# torch.save(cm_unet.state_dict(), '/content/drive/MyDrive/cv_model/cm_model.pth')

Задание 5¶

Генерация с помощью обученной консистенси модели¶

Настало время погенерировать картинки с помощью нашей модели. Напомним, что мы не можем для консистенси моделей использовать DDIM и другие классические солверы для диффузии. Нам нужен специальный сэмплер для CM, который схематично изображен на картинке ниже:

No description has been provided for this image

Чуть более формально:

$x_{t_n} \sim {N}(0, I)$

$for\ t_i \in [t_n, ..., t_1]:$

  • $\epsilon \leftarrow unet(x_{t_i})$

  • $x_0 \leftarrow DDIM(\epsilon, x_{t_i}, t_i, 0)$

  • $x_{t_{i-1}} \leftarrow q(x_{t_{i-1}} | x_0)$

Classifier-free guidance (CFG)

Также вам надо реализовать поддержку CFG в CM сэмплирование. Вспомним формулу:

$\epsilon_w = {\color{blue}{\epsilon_{uncond}}} + w \cdot (\epsilon_{cond} - \epsilon_{uncond})$, где $w \geq 1$

Обратим внимание, что режим "без гайденса" соотвествует $w = 1$, что немного контринтуитивно, но в большинстве реализаций будет встречаться именно такой вид этой формулы.

In [11]:
@torch.no_grad()
def consistency_sampling(
    pipe,
    prompt,
    num_inference_steps=4,
    generator=None,
    num_images_per_prompt=4,
    guidance_scale=1
):
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Извлекаем эмбеды из текстовых промптов. Реализуйте вызов pipe.encode_prompt
    do_classifier_free_guidance = guidance_scale > 0
    prompt_embeds = pipe.encode_prompt(
        prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance
    )[0]
    null_prompt_embeds = pipe.encode_prompt(
        [""] * batch_size, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance
    )[0]

    assert prompt_embeds.dtype == null_prompt_embeds.dtype == torch.float16

    # Настраиваем параметры scheduler-a
    assert pipe.scheduler.config['timestep_spacing'] == 'trailing'
    pipe.scheduler.set_timesteps(num_inference_steps)

    # Создаем батч латентов из N(0,I)
    latents = torch.randn(
        (batch_size * num_images_per_prompt, pipe.unet.in_channels, pipe.unet.sample_size, pipe.unet.sample_size),
        generator=generator,
        device=device,
        dtype=torch.float16,
    )

    for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
        t = torch.tensor([t] * len(latents)).to(device)
        zero_t = torch.tensor([0] * len(latents)).to(device)

        cond_noise_pred = pipe.unet(
            latents, t, encoder_hidden_states=prompt_embeds
        ).sample

        if do_classifier_free_guidance:
            uncond_noise_pred = pipe.unet(
                latents, t, encoder_hidden_states=null_prompt_embeds
            ).sample

            noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
        else:
            noise_pred = cond_noise_pred

        # Получаем x_0 оценку из x_t
        x_0 = ddim_solver_step(
            model_output=noise_pred, x_t=latents, t=t, s=zero_t, scheduler=pipe.scheduler
        )

        if i + 1 < num_inference_steps:
            # Переход на следующий шаг
            s = pipe.scheduler.timesteps[i+1]
            s = torch.tensor([s] * len(latents)).to(device)

            latents = q_sample(x=x_0, t=s, scheduler=pipe.scheduler)
        else:
            # Последний шаг
            latents = x_0

        latents = latents.half()

    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    do_denormalize = [True] * image.shape[0]
    image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)
    return image

Попробуем сгененировать что-то нашей моделью. Можно поиграться с разными сидами и гайденс скейлами.

Референс, что примерно должно получиться на этом этапе для guidance_scale=2. Как видите, картинки стали почетче, но пока все еще так себе.

img

In [ ]:
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'ct'

generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 3

# Заменяем генерацию пайплайном на наше сэмплирование.
images = consistency_sampling(
    pipe,
    prompt,
    num_inference_steps=4,
    generator=generator,
    num_images_per_prompt=4,
    guidance_scale=guidance_scale
)

visualize_images(images)
/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py:364: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.
  return getattr(self.model, name)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image

Consistency Distillation¶

Задание №6¶

Теперь давайте попробуем перейти к постановке дистилляции, где шаг из $x_t$ в $x_s$ будет делаться не аналитически, а c помощью модели учителя.

$\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$

$\mathbf{x}_s = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, s)$

Замечание: В text-to-image генерации classifier-free guidance (CFG) играет очень важную роль для получения хорошего качества с помощью диффузии. CFG меняет траектории ODE и раз нам он важен, то давайте и дистиллировать траектории с CFG.

Поэтому для получения точки $\mathbf{x}_{s}$ мы будем использовать шаг учителя с CFG. Это важное отличие от CT сеттинга - там мы не можем моделировать гайденс.

In [ ]:
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32

# Добавляем новые LoRA адаптеры для CD модели
cm_unet.add_adapter("cd", lora_config)
cm_unet.set_adapter("cd")

# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
In [ ]:
@torch.no_grad()
def get_xs_from_xt_with_teacher(
    x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
    scheduler,
    prompt_embeds,
    teacher_unet,
    guidance_scale,
    **kwargs
):
    # Делаем предсказание учителем в кондишион случае: подаем эмбеды текста
    cond_noise_pred = teacher_unet(
        x_t, t, encoder_hidden_states=prompt_embeds
    ).sample

    # Для CFG нам нужно делать предсказания в unconditional случае.
    # Для T2I моделей, мы будем это моделировать предсказаниями для пустого промпта ""
    # Извлечем эмбеды из пустого промпта и размножить их до размера батча
    uncond_input_ids = pipe.tokenizer(
        [""], return_tensors="pt", padding="max_length", max_length=77
    ).input_ids.to("cuda")

    uncond_prompt_embeds = pipe.text_encoder(uncond_input_ids)[0].expand(
        *prompt_embeds.shape
    )

    # Затем прогоняем модель для пустых промптов
    uncond_noise_pred = teacher_unet(
        x_t, t, encoder_hidden_states=uncond_prompt_embeds
    ).sample

    # Применяем CFG формулу и получаем итоговый предикт учителя
    noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)

    # Получаем x_s из x_t
    x_s = ddim_solver_step(
        model_output=noise_pred, x_t=x_t, t=t, s=s, scheduler=scheduler
    )

    return x_s
In [ ]:
# Сразу зададим внутрь модель учителя и guidance_scale
get_xs_from_xt_with_teacher = functools.partial(
    get_xs_from_xt_with_teacher,
    teacher_unet=teacher_unet,
    guidance_scale=7.5
)

Еще, как показано в работе Improved Techniques for Training Consistency Models. L2 лосс не самый оптимальный выбор для консистенси моделей. Давайте в CD обучении также заменим MSE лосс на pseudo-huber лосс из статьи.

In [ ]:
def pseudo_huber_loss(
    x: torch.Tensor,
    y: torch.Tensor,
    c=0.001
):
    diff = x - y
    loss = torch.sum(torch.sqrt(diff**2 + c**2) - c)
    return loss
In [ ]:
cd_loss = functools.partial(
    cm_loss_template,

    loss_fn=pseudo_huber_loss,
    get_boundary_timesteps=get_zero_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_with_teacher
)

assert cm_unet.active_adapter == 'cd'

Теперь обучим модель в CD режиме

Лосс большой поскольку не добавил в huber loss усреднение, но картинки получаются все равно хорошего качества!!!¶

In [ ]:
num_grad_accum = 2 # обновляем параметры каждые 2 шага

train_loop(cm_unet, pipe, train_dataloader, optimizer, cd_loss, num_grad_accum)
<ipython-input-17-04d1c509c8fc>:7: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler()
  0%|          | 0/625 [00:00<?, ?it/s]
<ipython-input-17-04d1c509c8fc>:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(dtype=torch.float16):  # Включаем mixed precision
Step 1, Loss: 3067.6396484375
Step 2, Loss: 4501.21533203125
Step 3, Loss: 3528.36181640625
Step 4, Loss: 2838.5224609375
Step 5, Loss: 3635.279296875
Step 6, Loss: 3245.017822265625
Step 7, Loss: 4237.431640625
Step 8, Loss: 3144.990966796875
Step 9, Loss: 4142.67529296875
Step 10, Loss: 3631.2841796875
Step 11, Loss: 3534.72998046875
Step 12, Loss: 1901.096923828125
Step 13, Loss: 2335.93798828125
Step 14, Loss: 2924.720703125
Step 15, Loss: 3095.15673828125
Step 16, Loss: 2338.27099609375
Step 17, Loss: 2354.35693359375
Step 18, Loss: 2751.059814453125
Step 19, Loss: 2736.702392578125
Step 20, Loss: 2632.9658203125
Step 21, Loss: 3111.861328125
Step 22, Loss: 5292.263671875
Step 23, Loss: 2556.02294921875
Step 24, Loss: 2551.59814453125
Step 25, Loss: 3263.38623046875
Step 26, Loss: 2465.36083984375
Step 27, Loss: 2000.357666015625
Step 28, Loss: 2800.03466796875
Step 29, Loss: 2182.60546875
Step 30, Loss: 5640.53857421875
Step 31, Loss: 4463.96875
Step 32, Loss: 5104.5107421875
Step 33, Loss: 2933.88916015625
Step 34, Loss: 3778.8603515625
Step 35, Loss: 4359.3681640625
Step 36, Loss: 3861.1748046875
Step 37, Loss: 3387.4375
Step 38, Loss: 3801.18701171875
Step 39, Loss: 3751.1162109375
Step 40, Loss: 3574.050048828125
Step 41, Loss: 2936.130859375
Step 42, Loss: 2676.213623046875
Step 43, Loss: 3132.67138671875
Step 44, Loss: 4475.865234375
Step 45, Loss: 3380.264892578125
Step 46, Loss: 4189.265625
Step 47, Loss: 3740.886962890625
Step 48, Loss: 2627.85498046875
Step 49, Loss: 2620.080322265625
Step 50, Loss: 2757.20849609375
Step 51, Loss: 2848.977783203125
Step 52, Loss: 3051.163330078125
Step 53, Loss: 2933.139892578125
Step 54, Loss: 3533.27294921875
Step 55, Loss: 4150.51806640625
Step 56, Loss: 2986.15283203125
Step 57, Loss: 3702.853515625
Step 58, Loss: 3069.2958984375
Step 59, Loss: 3301.21435546875
Step 60, Loss: 4020.974853515625
Step 61, Loss: 3972.615966796875
Step 62, Loss: 3032.62890625
Step 63, Loss: 3643.7158203125
Step 64, Loss: 4450.271484375
Step 65, Loss: 3761.891357421875
Step 66, Loss: 3120.98046875
Step 67, Loss: 4099.08447265625
Step 68, Loss: 3424.960693359375
Step 69, Loss: 3724.317138671875
Step 70, Loss: 5375.58984375
Step 71, Loss: 5073.2177734375
Step 72, Loss: 3852.9150390625
Step 73, Loss: 2907.05078125
Step 74, Loss: 3241.89794921875
Step 75, Loss: 5088.7919921875
Step 76, Loss: 3307.97802734375
Step 77, Loss: 3389.269287109375
Step 78, Loss: 4032.13720703125
Step 79, Loss: 4095.20703125
Step 80, Loss: 2576.86669921875
Step 81, Loss: 6043.5419921875
Step 82, Loss: 5501.4296875
Step 83, Loss: 6051.47021484375
Step 84, Loss: 4081.576171875
Step 85, Loss: 3376.9384765625
Step 86, Loss: 3990.83251953125
Step 87, Loss: 3614.535888671875
Step 88, Loss: 6747.9072265625
Step 89, Loss: 5914.2421875
Step 90, Loss: 4253.474609375
Step 91, Loss: 6455.4697265625
Step 92, Loss: 4305.9140625
Step 93, Loss: 8737.1953125
Step 94, Loss: 6071.18212890625
Step 95, Loss: 5942.138671875
Step 96, Loss: 4217.1015625
Step 97, Loss: 2322.83642578125
Step 98, Loss: 3829.908935546875
Step 99, Loss: 3482.105224609375
Step 100, Loss: 5649.88330078125
Step 101, Loss: 9304.447265625
Step 102, Loss: 5281.0556640625
Step 103, Loss: 5694.24951171875
Step 104, Loss: 7078.431640625
Step 105, Loss: 4312.2626953125
Step 106, Loss: 3901.015869140625
Step 107, Loss: 4590.02978515625
Step 108, Loss: 7480.7421875
Step 109, Loss: 6370.4169921875
Step 110, Loss: 7317.55419921875
Step 111, Loss: 4277.236328125
Step 112, Loss: 7304.1455078125
Step 113, Loss: 12386.59765625
Step 114, Loss: 5439.9638671875
Step 115, Loss: 4104.0244140625
Step 116, Loss: 3760.8095703125
Step 117, Loss: 6188.7431640625
Step 118, Loss: 10398.21875
Step 119, Loss: 6285.6279296875
Step 120, Loss: 6792.962890625
Step 121, Loss: 7379.0625
Step 122, Loss: 8344.27734375
Step 123, Loss: 7244.67529296875
Step 124, Loss: 12012.86328125
Step 125, Loss: 7564.287109375
Step 126, Loss: 3739.274658203125
Step 127, Loss: 4759.7197265625
Step 128, Loss: 4001.96142578125
Step 129, Loss: 4556.54443359375
Step 130, Loss: 9362.6318359375
Step 131, Loss: 6356.1318359375
Step 132, Loss: 4285.30712890625
Step 133, Loss: 7411.22314453125
Step 134, Loss: 10271.453125
Step 135, Loss: 5382.138671875
Step 136, Loss: 8593.34765625
Step 137, Loss: 7603.85986328125
Step 138, Loss: 11660.451171875
Step 139, Loss: 6245.08642578125
Step 140, Loss: 5117.90087890625
Step 141, Loss: 6229.255859375
Step 142, Loss: 7945.4853515625
Step 143, Loss: 6740.63916015625
Step 144, Loss: 5087.9970703125
Step 145, Loss: 9219.767578125
Step 146, Loss: 6910.55419921875
Step 147, Loss: 8740.2607421875
Step 148, Loss: 4533.8310546875
Step 149, Loss: 7340.5283203125
Step 150, Loss: 6660.20458984375
Step 151, Loss: 5197.2548828125
Step 152, Loss: 7767.736328125
Step 153, Loss: 5133.13232421875
Step 154, Loss: 7437.169921875
Step 155, Loss: 7853.3349609375
Step 156, Loss: 9258.6396484375
Step 157, Loss: 7418.76318359375
Step 158, Loss: 6626.232421875
Step 159, Loss: 3714.283447265625
Step 160, Loss: 5961.29248046875
Step 161, Loss: 5662.3291015625
Step 162, Loss: 6956.33642578125
Step 163, Loss: 8190.3017578125
Step 164, Loss: 8912.3173828125
Step 165, Loss: 5106.20068359375
Step 166, Loss: 5819.31298828125
Step 167, Loss: 7879.0009765625
Step 168, Loss: 3761.307861328125
Step 169, Loss: 5512.0107421875
Step 170, Loss: 9456.24609375
Step 171, Loss: 9358.2021484375
Step 172, Loss: 6571.73291015625
Step 173, Loss: 5759.2841796875
Step 174, Loss: 4750.193359375
Step 175, Loss: 3060.482666015625
Step 176, Loss: 3645.848388671875
Step 177, Loss: 5293.1494140625
Step 178, Loss: 6402.9169921875
Step 179, Loss: 6034.236328125
Step 180, Loss: 6270.8154296875
Step 181, Loss: 4819.3349609375
Step 182, Loss: 8022.818359375
Step 183, Loss: 13358.484375
Step 184, Loss: 7847.2744140625
Step 185, Loss: 8682.337890625
Step 186, Loss: 5131.509765625
Step 187, Loss: 4247.0625
Step 188, Loss: 9087.06640625
Step 189, Loss: 8091.3076171875
Step 190, Loss: 4227.57958984375
Step 191, Loss: 6432.40087890625
Step 192, Loss: 5576.962890625
Step 193, Loss: 1886.610595703125
Step 194, Loss: 9776.810546875
Step 195, Loss: 5033.0546875
Step 196, Loss: 9555.078125
Step 197, Loss: 6725.07958984375
Step 198, Loss: 2349.116455078125
Step 199, Loss: 6824.6591796875
Step 200, Loss: 3863.759765625
Step 201, Loss: 5391.271484375
Step 202, Loss: 3628.31787109375
Step 203, Loss: 4978.67431640625
Step 204, Loss: 6335.658203125
Step 205, Loss: 4753.02001953125
Step 206, Loss: 4683.0947265625
Step 207, Loss: 2700.89111328125
Step 208, Loss: 8061.462890625
Step 209, Loss: 8857.8671875
Step 210, Loss: 8114.0439453125
Step 211, Loss: 5542.8154296875
Step 212, Loss: 6195.6015625
Step 213, Loss: 7623.865234375
Step 214, Loss: 6993.47509765625
Step 215, Loss: 5519.9560546875
Step 216, Loss: 6307.55908203125
Step 217, Loss: 7700.26806640625
Step 218, Loss: 5104.650390625
Step 219, Loss: 7564.41845703125
Step 220, Loss: 8365.9931640625
Step 221, Loss: 3689.85302734375
Step 222, Loss: 7112.26416015625
Step 223, Loss: 4709.306640625
Step 224, Loss: 7390.650390625
Step 225, Loss: 5056.04931640625
Step 226, Loss: 8604.79296875
Step 227, Loss: 6731.34716796875
Step 228, Loss: 6824.1826171875
Step 229, Loss: 6225.33837890625
Step 230, Loss: 4657.8828125
Step 231, Loss: 5898.32373046875
Step 232, Loss: 8481.3466796875
Step 233, Loss: 6476.8310546875
Step 234, Loss: 11240.8056640625
Step 235, Loss: 7373.90771484375
Step 236, Loss: 5295.42431640625
Step 237, Loss: 4240.84375
Step 238, Loss: 8358.541015625
Step 239, Loss: 4117.34228515625
Step 240, Loss: 8481.828125
Step 241, Loss: 10369.083984375
Step 242, Loss: 4568.6171875
Step 243, Loss: 4461.9072265625
Step 244, Loss: 7004.130859375
Step 245, Loss: 5408.955078125
Step 246, Loss: 8705.5478515625
Step 247, Loss: 6420.23388671875
Step 248, Loss: 5777.34716796875
Step 249, Loss: 5146.31494140625
Step 250, Loss: 3864.042724609375
Step 251, Loss: 6520.03759765625
Step 252, Loss: 6385.6064453125
Step 253, Loss: 5411.78369140625
Step 254, Loss: 7851.2548828125
Step 255, Loss: 6904.91552734375
Step 256, Loss: 2874.87451171875
Step 257, Loss: 5579.52587890625
Step 258, Loss: 5122.4072265625
Step 259, Loss: 10255.5693359375
Step 260, Loss: 2960.11669921875
Step 261, Loss: 6385.0810546875
Step 262, Loss: 2329.3837890625
Step 263, Loss: 8673.728515625
Step 264, Loss: 5706.6884765625
Step 265, Loss: 6505.4814453125
Step 266, Loss: 10179.833984375
Step 267, Loss: 1975.400634765625
Step 268, Loss: 13108.318359375
Step 269, Loss: 6965.330078125
Step 270, Loss: 5480.1484375
Step 271, Loss: 4018.4619140625
Step 272, Loss: 4464.1796875
Step 273, Loss: 8933.349609375
Step 274, Loss: 8485.5166015625
Step 275, Loss: 5496.9736328125
Step 276, Loss: 7384.4560546875
Step 277, Loss: 4863.380859375
Step 278, Loss: 7213.10009765625
Step 279, Loss: 5356.1591796875
Step 280, Loss: 7484.4326171875
Step 281, Loss: 6442.62255859375
Step 282, Loss: 8255.208984375
Step 283, Loss: 6983.66064453125
Step 284, Loss: 4387.44921875
Step 285, Loss: 7950.56689453125
Step 286, Loss: 4250.01953125
Step 287, Loss: 6698.52392578125
Step 288, Loss: 4678.005859375
Step 289, Loss: 5268.7900390625
Step 290, Loss: 7555.23291015625
Step 291, Loss: 7431.2294921875
Step 292, Loss: 1926.84765625
Step 293, Loss: 7947.033203125
Step 294, Loss: 7582.3505859375
Step 295, Loss: 4912.58935546875
Step 296, Loss: 3210.071044921875
Step 297, Loss: 4715.45947265625
Step 298, Loss: 11928.0234375
Step 299, Loss: 2795.9111328125
Step 300, Loss: 2382.36865234375
Step 301, Loss: 3922.52587890625
Step 302, Loss: 6586.05859375
Step 303, Loss: 4927.994140625
Step 304, Loss: 4698.634765625
Step 305, Loss: 5751.0263671875
Step 306, Loss: 4140.193359375
Step 307, Loss: 5998.912109375
Step 308, Loss: 8095.69873046875
Step 309, Loss: 4656.4951171875
Step 310, Loss: 5908.8486328125
Step 311, Loss: 5935.57275390625
Step 312, Loss: 6386.28125
Step 313, Loss: 6940.20068359375
Step 314, Loss: 9142.6171875
Step 315, Loss: 2590.2939453125
Step 316, Loss: 6724.3203125
Step 317, Loss: 3780.9267578125
Step 318, Loss: 5891.82373046875
Step 319, Loss: 5028.5029296875
Step 320, Loss: 5054.27783203125
Step 321, Loss: 9260.90625
Step 322, Loss: 4988.005859375
Step 323, Loss: 4460.4697265625
Step 324, Loss: 5937.76708984375
Step 325, Loss: 5603.0947265625
Step 326, Loss: 5106.73291015625
Step 327, Loss: 6049.591796875
Step 328, Loss: 7533.95947265625
Step 329, Loss: 3525.634765625
Step 330, Loss: 5791.76416015625
Step 331, Loss: 8983.359375
Step 332, Loss: 4312.0771484375
Step 333, Loss: 10512.8359375
Step 334, Loss: 4606.330078125
Step 335, Loss: 3263.814453125
Step 336, Loss: 4092.023681640625
Step 337, Loss: 6944.146484375
Step 338, Loss: 4410.89111328125
Step 339, Loss: 7504.6416015625
Step 340, Loss: 3232.72412109375
Step 341, Loss: 5564.2734375
Step 342, Loss: 8555.3056640625
Step 343, Loss: 6968.38037109375
Step 344, Loss: 9392.2421875
Step 345, Loss: 4656.8740234375
Step 346, Loss: 5977.08203125
Step 347, Loss: 3161.54931640625
Step 348, Loss: 8505.0556640625
Step 349, Loss: 4393.68896484375
Step 350, Loss: 6235.12353515625
Step 351, Loss: 5617.68798828125
Step 352, Loss: 5996.26416015625
Step 353, Loss: 3228.28857421875
Step 354, Loss: 3934.18798828125
Step 355, Loss: 6701.453125
Step 356, Loss: 3992.371826171875
Step 357, Loss: 4660.65625
Step 358, Loss: 4088.6162109375
Step 359, Loss: 7012.5205078125
Step 360, Loss: 5136.60400390625
Step 361, Loss: 5305.50244140625
Step 362, Loss: 5424.4150390625
Step 363, Loss: 7307.763671875
Step 364, Loss: 7429.21875
Step 365, Loss: 6045.5986328125
Step 366, Loss: 4839.5751953125
Step 367, Loss: 4386.4248046875
Step 368, Loss: 3612.373046875
Step 369, Loss: 4420.654296875
Step 370, Loss: 1770.496337890625
Step 371, Loss: 4422.84326171875
Step 372, Loss: 4809.2705078125
Step 373, Loss: 4419.31494140625
Step 374, Loss: 6618.95068359375
Step 375, Loss: 5547.0146484375
Step 376, Loss: 6225.068359375
Step 377, Loss: 6375.6171875
Step 378, Loss: 2733.548583984375
Step 379, Loss: 5085.37646484375
Step 380, Loss: 1482.4791259765625
Step 381, Loss: 5116.40625
Step 382, Loss: 7516.4189453125
Step 383, Loss: 7153.5146484375
Step 384, Loss: 5548.19921875
Step 385, Loss: 6198.091796875
Step 386, Loss: 2885.35498046875
Step 387, Loss: 9088.642578125
Step 388, Loss: 10132.908203125
Step 389, Loss: 6325.3017578125
Step 390, Loss: 6302.7333984375
Step 391, Loss: 9528.0703125
Step 392, Loss: 4928.8505859375
Step 393, Loss: 6787.09619140625
Step 394, Loss: 5132.1337890625
Step 395, Loss: 7338.17529296875
Step 396, Loss: 4688.7548828125
Step 397, Loss: 5553.365234375
Step 398, Loss: 5271.1220703125
Step 399, Loss: 7157.4794921875
Step 400, Loss: 6882.8994140625
Step 401, Loss: 3169.628173828125
Step 402, Loss: 5126.3876953125
Step 403, Loss: 2943.015625
Step 404, Loss: 6440.2275390625
Step 405, Loss: 6651.93359375
Step 406, Loss: 4473.77001953125
Step 407, Loss: 7756.29296875
Step 408, Loss: 5183.6806640625
Step 409, Loss: 6820.4404296875
Step 410, Loss: 5539.83642578125
Step 411, Loss: 3924.71337890625
Step 412, Loss: 8651.986328125
Step 413, Loss: 6247.2685546875
Step 414, Loss: 8581.720703125
Step 415, Loss: 4809.62744140625
Step 416, Loss: 9160.2060546875
Step 417, Loss: 7738.10986328125
Step 418, Loss: 6360.07958984375
Step 419, Loss: 6054.931640625
Step 420, Loss: 7206.0810546875
Step 421, Loss: 5747.8359375
Step 422, Loss: 4190.154296875
Step 423, Loss: 8808.7548828125
Step 424, Loss: 10492.1015625
Step 425, Loss: 5053.8681640625
Step 426, Loss: 7452.0703125
Step 427, Loss: 4237.7578125
Step 428, Loss: 2869.2587890625
Step 429, Loss: 4189.037109375
Step 430, Loss: 5199.412109375
Step 431, Loss: 6138.0419921875
Step 432, Loss: 5768.3564453125
Step 433, Loss: 4534.21533203125
Step 434, Loss: 9090.6171875
Step 435, Loss: 6423.3310546875
Step 436, Loss: 7451.03369140625
Step 437, Loss: 4777.95654296875
Step 438, Loss: 6288.56494140625
Step 439, Loss: 7394.0986328125
Step 440, Loss: 4003.29052734375
Step 441, Loss: 6014.408203125
Step 442, Loss: 5068.107421875
Step 443, Loss: 5231.33984375
Step 444, Loss: 4120.45263671875
Step 445, Loss: 3309.422607421875
Step 446, Loss: 6272.455078125
Step 447, Loss: 7732.92919921875
Step 448, Loss: 4990.6318359375
Step 449, Loss: 3847.4287109375
Step 450, Loss: 9725.22265625
Step 451, Loss: 4568.7822265625
Step 452, Loss: 6251.55078125
Step 453, Loss: 4916.6533203125
Step 454, Loss: 6205.8486328125
Step 455, Loss: 3742.043701171875
Step 456, Loss: 6726.53515625
Step 457, Loss: 7195.6337890625
Step 458, Loss: 3889.393798828125
Step 459, Loss: 6712.830078125
Step 460, Loss: 6117.98828125
Step 461, Loss: 4891.23681640625
Step 462, Loss: 4433.759765625
Step 463, Loss: 7305.0263671875
Step 464, Loss: 5478.650390625
Step 465, Loss: 5166.064453125
Step 466, Loss: 3683.619873046875
Step 467, Loss: 5027.83056640625
Step 468, Loss: 4194.958984375
Step 469, Loss: 5198.7314453125
Step 470, Loss: 2188.5341796875
Step 471, Loss: 4224.28759765625
Step 472, Loss: 6546.12890625
Step 473, Loss: 8023.783203125
Step 474, Loss: 5774.46875
Step 475, Loss: 11540.9658203125
Step 476, Loss: 4529.5849609375
Step 477, Loss: 7224.9892578125
Step 478, Loss: 7143.7255859375
Step 479, Loss: 4003.281005859375
Step 480, Loss: 4674.6953125
Step 481, Loss: 4016.17138671875
Step 482, Loss: 5501.24169921875
Step 483, Loss: 6691.3173828125
Step 484, Loss: 5257.6669921875
Step 485, Loss: 4468.6845703125
Step 486, Loss: 6899.38671875
Step 487, Loss: 7834.3291015625
Step 488, Loss: 2152.55078125
Step 489, Loss: 3332.0595703125
Step 490, Loss: 5432.7578125
Step 491, Loss: 5623.12939453125
Step 492, Loss: 4124.57568359375
Step 493, Loss: 2991.28955078125
Step 494, Loss: 2484.10107421875
Step 495, Loss: 3913.890869140625
Step 496, Loss: 3984.5537109375
Step 497, Loss: 6917.552734375
Step 498, Loss: 5427.06201171875
Step 499, Loss: 6241.8974609375
Step 500, Loss: 6371.95849609375
Step 501, Loss: 6826.9833984375
Step 502, Loss: 4729.8173828125
Step 503, Loss: 8890.791015625
Step 504, Loss: 4798.78515625
Step 505, Loss: 7523.17578125
Step 506, Loss: 6336.71826171875
Step 507, Loss: 5609.0830078125
Step 508, Loss: 7567.03125
Step 509, Loss: 5495.232421875
Step 510, Loss: 8643.080078125
Step 511, Loss: 8627.75
Step 512, Loss: 4278.63134765625
Step 513, Loss: 2802.3037109375
Step 514, Loss: 5261.93017578125
Step 515, Loss: 7480.02734375
Step 516, Loss: 4462.88671875
Step 517, Loss: 4688.5537109375
Step 518, Loss: 5175.33984375
Step 519, Loss: 5336.74609375
Step 520, Loss: 4055.109375
Step 521, Loss: 8050.234375
Step 522, Loss: 4862.365234375
Step 523, Loss: 2215.04345703125
Step 524, Loss: 6717.880859375
Step 525, Loss: 7073.34619140625
Step 526, Loss: 5147.61962890625
Step 527, Loss: 5916.01513671875
Step 528, Loss: 9250.62109375
Step 529, Loss: 4330.779296875
Step 530, Loss: 6731.103515625
Step 531, Loss: 5912.619140625
Step 532, Loss: 6973.14501953125
Step 533, Loss: 5659.2822265625
Step 534, Loss: 6368.8603515625
Step 535, Loss: 2829.69677734375
Step 536, Loss: 4062.171875
Step 537, Loss: 8312.2265625
Step 538, Loss: 7402.71923828125
Step 539, Loss: 3628.8916015625
Step 540, Loss: 8412.5283203125
Step 541, Loss: 8000.80126953125
Step 542, Loss: 7564.3359375
Step 543, Loss: 8388.765625
Step 544, Loss: 6130.25390625
Step 545, Loss: 2406.29541015625
Step 546, Loss: 6641.8291015625
Step 547, Loss: 11002.904296875
Step 548, Loss: 7771.71240234375
Step 549, Loss: 5968.1435546875
Step 550, Loss: 9708.85546875
Step 551, Loss: 9129.2890625
Step 552, Loss: 5903.83251953125
Step 553, Loss: 5543.8310546875
Step 554, Loss: 6419.64892578125
Step 555, Loss: 2232.00830078125
Step 556, Loss: 5184.1640625
Step 557, Loss: 5963.85693359375
Step 558, Loss: 4534.2802734375
Step 559, Loss: 5421.09228515625
Step 560, Loss: 3904.07958984375
Step 561, Loss: 6978.623046875
Step 562, Loss: 1461.27099609375
Step 563, Loss: 12069.4833984375
Step 564, Loss: 7750.2607421875
Step 565, Loss: 7577.3076171875
Step 566, Loss: 3827.16259765625
Step 567, Loss: 5274.05322265625
Step 568, Loss: 7263.1240234375
Step 569, Loss: 6747.3330078125
Step 570, Loss: 4211.787109375
Step 571, Loss: 4571.31787109375
Step 572, Loss: 5386.0546875
Step 573, Loss: 7221.90087890625
Step 574, Loss: 3458.982177734375
Step 575, Loss: 7115.873046875
Step 576, Loss: 8545.783203125
Step 577, Loss: 6167.5146484375
Step 578, Loss: 5334.92333984375
Step 579, Loss: 4170.1953125
Step 580, Loss: 4024.8974609375
Step 581, Loss: 4247.57666015625
Step 582, Loss: 5386.41552734375
Step 583, Loss: 3921.037109375
Step 584, Loss: 5623.51171875
Step 585, Loss: 7213.2265625
Step 586, Loss: 5052.44921875
Step 587, Loss: 4799.7138671875
Step 588, Loss: 9004.5625
Step 589, Loss: 4952.087890625
Step 590, Loss: 6750.4462890625
Step 591, Loss: 5455.04931640625
Step 592, Loss: 7526.740234375
Step 593, Loss: 2873.11083984375
Step 594, Loss: 2919.0390625
Step 595, Loss: 4697.41015625
Step 596, Loss: 4942.58935546875
Step 597, Loss: 3529.548828125
Step 598, Loss: 5228.578125
Step 599, Loss: 5086.94970703125
Step 600, Loss: 6321.24951171875
Step 601, Loss: 2977.922119140625
Step 602, Loss: 4450.2197265625
Step 603, Loss: 4080.992431640625
Step 604, Loss: 4699.5341796875
Step 605, Loss: 2120.114990234375
Step 606, Loss: 4223.296875
Step 607, Loss: 4961.94775390625
Step 608, Loss: 7770.8603515625
Step 609, Loss: 9213.3515625
Step 610, Loss: 4143.97509765625
Step 611, Loss: 3693.976806640625
Step 612, Loss: 4183.45361328125
Step 613, Loss: 6152.4052734375
Step 614, Loss: 6194.6533203125
Step 615, Loss: 6285.8759765625
Step 616, Loss: 5739.7294921875
Step 617, Loss: 7159.712890625
Step 618, Loss: 3027.54638671875
Step 619, Loss: 6115.3916015625
Step 620, Loss: 3153.59716796875
Step 621, Loss: 5165.57763671875
Step 622, Loss: 4092.1318359375
Step 623, Loss: 6564.84033203125
Step 624, Loss: 8872.041015625
Step 625, Loss: 4565.2060546875
In [ ]:
# torch.save(cm_unet.state_dict(), '/content/drive/MyDrive/cv_model/cd_model.pth')

Снова сэмплируем¶

Обратим внимание, что тут мы сэмпилруем без гайденса, потому что мы его уже частично прокинули в модель, когда делали шаг учителя с CFG.

Снова для референса приводим картинки на этом этапе:

img

Ваши картинки не обязаны совпадать: у вас могут быть немного менее/более качественные. Небольшая разница по качеству на оценку не влиет.

In [ ]:
# Подставляем нашу новую обученную модель в пайплайн
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'cd'

generator = torch.Generator(device="cuda").manual_seed(0)
guidance_scale = 0

images = consistency_sampling(
    pipe,
    prompt,
    num_inference_steps=4,
    generator=generator,
    num_images_per_prompt=4,
    guidance_scale=guidance_scale
)

visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image

Давайте посмотрим на картинки для других промптов¶

In [12]:
validation_prompts = [
    "A sad puppy with large eyes",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
    "A girl with pale blue hair and a cami tank top",
    "A lighthouse in a giant wave, origami style",
    "belle epoque, christmas, red house in the forest, photo realistic, 8k",
    "A small cactus with a happy face in the Sahara desert",
    "Green commercial building with refrigerator and refrigeration units outside",
]
In [ ]:
for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(0)

    images = consistency_sampling(
        pipe,
        prompt,
        num_inference_steps=4,
        generator=generator,
        num_images_per_prompt=4,
        guidance_scale=guidance_scale
    )

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Multi-boundary Сonsistency Distillation¶

No description has been provided for this image

В конце мы рассмотрим недавнюю модификацию CD, Multi-boundary CD, где интегрируем не всю траекторию сразу и потом сэмплируем с возвращением назад, а разбиваем траектории на $K$ отрезков и применяет CD внутри каждого отрезка независимо. Например, на картинке выше у нас два отрезка: зеленым и красным выделены две граничные точки. Для классического CD, рассмотренного ранее, у нас только одна граничная точка в $t = 0$

Обратим внимание, что сэмплирование становится детерминистичным и можно снова использовать DDIM солвер, где число шагов равно числу интервалов $K$, на которые мы разбили траектории во время обучения.

Этот метод гораздо лучше работает чем обычный CD, потому что решать задачу CD на отрезках, а не на всей траектории, гораздо проще. В текущем задании мы разобьем траекторию на $K=4$ отрезка.

Подробнее почитать можно в этой статье.

Задание №7 (0.5 балла, сдается в контесте)¶

Ниже реализуйте функцию, которая для $K=4$ отрезков будет сопоставлять таймстепам соответствующие граничные точки.

Например, для $K=2$ отрезков граничные точки будут: [0, 499]

$0 \leq t < 499$ -> граничная точка - $0$

$499 \leq t < 999$ -> граничная точка - $499$

Ресурсы в колабе закончились, обучал последнюю модель в kaggle¶

In [ ]:
# Восстанавливаем модель UNet и оборачиваем её PEFT
unet = UNet2DConditionModel.from_pretrained(
    'sd-legacy/stable-diffusion-v1-5',
    subfolder='unet',
    torch_dtype=torch.float32,
).to('cuda')
In [ ]:
cm_unet = get_peft_model(unet, lora_config, adapter_name="multi-cd")
cm_unet.enable_gradient_checkpointing()
In [ ]:
# Создаем новый адаптер
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32

# Добавляем новый адаптер "multi-cd"
cm_unet.add_adapter("multi-cd", lora_config)
cm_unet.set_adapter("multi-cd")

# Пересоздаём оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
In [ ]:
def get_multi_boundary_timesteps(
    timesteps,
    num_boundaries=4,
    num_timesteps=1000,
):
    """
    Для батча таймстепов определяем соответствующие граничные точки.
    params:
        timesteps: torch.Tensor(batch_size, device='cuda')
    returns:
        boundary_timesteps: torch.Tensor(batch_size, device='cuda')
    """
    # Здесь важно аккуратно поработать с таймстепами,
    # чтобы не перелетать граничные точки и при этом иногда попадать в них.
    # Совет: повыводить timesteps и boundary_timesteps перед обучением.

    boundaries = torch.linspace(0, num_timesteps - 1, num_boundaries + 1, device=timesteps.device, dtype=torch.long)

    indices = torch.bucketize(timesteps, boundaries, right=False)

    boundary_timesteps = boundaries[torch.clamp(indices - 1, min=0)]

    return boundary_timesteps
In [ ]:
multi_cd_loss = functools.partial(
    cm_loss_template,

    loss_fn=pseudo_huber_loss,
    get_boundary_timesteps=get_multi_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_with_teacher
)
assert cm_unet.active_adapter == 'multi-cd'

Теперь обучим Multi-boundary CD модель

In [ ]:
torch.cuda.empty_cache()
In [ ]:
num_grad_accum = 2 # обновляем параметры каждые 2 шага

train_loop(cm_unet, pipe, train_dataloader, optimizer, multi_cd_loss, num_grad_accum)
/tmp/ipykernel_23/4028741595.py:7: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler()
  0%|          | 0/625 [00:00<?, ?it/s]
/tmp/ipykernel_23/4028741595.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(dtype=torch.float16):  # Включаем mixed precision
/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Step 1, Loss: 1412.8082275390625
Step 2, Loss: 941.15087890625
Step 3, Loss: 1324.7342529296875
Step 4, Loss: 1476.2618408203125
Step 5, Loss: 1478.176513671875
Step 6, Loss: 1110.04052734375
Step 7, Loss: 1325.424560546875
Step 8, Loss: 1190.1014404296875
Step 9, Loss: 1061.3260498046875
Step 10, Loss: 1054.9925537109375
Step 11, Loss: 1480.53125
Step 12, Loss: 1221.750732421875
Step 13, Loss: 1454.07177734375
Step 14, Loss: 1729.9173583984375
Step 15, Loss: 1381.732666015625
Step 16, Loss: 1263.1090087890625
Step 17, Loss: 1143.2685546875
Step 18, Loss: 981.5584106445312
Step 19, Loss: 1296.2276611328125
Step 20, Loss: 1196.6866455078125
Step 21, Loss: 1264.32666015625
Step 22, Loss: 1215.798583984375
Step 23, Loss: 1191.38671875
Step 24, Loss: 1388.1199951171875
Step 25, Loss: 1487.18896484375
Step 26, Loss: 1353.6796875
Step 27, Loss: 1160.75732421875
Step 28, Loss: 1861.5467529296875
Step 29, Loss: 1208.218017578125
Step 30, Loss: 1362.111572265625
Step 31, Loss: 1071.198974609375
Step 32, Loss: 1252.173095703125
Step 33, Loss: 1355.49072265625
Step 34, Loss: 1508.05126953125
Step 35, Loss: 1267.271240234375
Step 36, Loss: 1185.5947265625
Step 37, Loss: 1482.4945068359375
Step 38, Loss: 1181.853759765625
Step 39, Loss: 1394.738525390625
Step 40, Loss: 1645.787109375
Step 41, Loss: 1171.0008544921875
Step 42, Loss: 1287.83935546875
Step 43, Loss: 993.2245483398438
Step 44, Loss: 1151.0255126953125
Step 45, Loss: 1318.048583984375
Step 46, Loss: 924.6395263671875
Step 47, Loss: 1511.181396484375
Step 48, Loss: 1202.577880859375
Step 49, Loss: 1508.792236328125
Step 50, Loss: 1505.3642578125
Step 51, Loss: 1560.315673828125
Step 52, Loss: 1317.7265625
Step 53, Loss: 1494.878662109375
Step 54, Loss: 1350.65966796875
Step 55, Loss: 1119.683837890625
Step 56, Loss: 1108.71435546875
Step 57, Loss: 1348.2799072265625
Step 58, Loss: 1156.2025146484375
Step 59, Loss: 1242.4732666015625
Step 60, Loss: 1325.41650390625
Step 61, Loss: 1322.845703125
Step 62, Loss: 1502.64111328125
Step 63, Loss: 1255.45654296875
Step 64, Loss: 1404.2796630859375
Step 65, Loss: 1595.4014892578125
Step 66, Loss: 1242.0611572265625
Step 67, Loss: 1469.566162109375
Step 68, Loss: 1601.4737548828125
Step 69, Loss: 1300.706787109375
Step 70, Loss: 1010.876220703125
Step 71, Loss: 1304.723876953125
Step 72, Loss: 1525.03662109375
Step 73, Loss: 973.072265625
Step 74, Loss: 1136.6474609375
Step 75, Loss: 1350.9725341796875
Step 76, Loss: 1347.3695068359375
Step 77, Loss: 1024.8515625
Step 78, Loss: 1117.653076171875
Step 79, Loss: 1313.831787109375
Step 80, Loss: 1625.9874267578125
Step 81, Loss: 1471.9322509765625
Step 82, Loss: 2217.120361328125
Step 83, Loss: 1050.25341796875
Step 84, Loss: 1215.480224609375
Step 85, Loss: 1341.806640625
Step 86, Loss: 1674.99072265625
Step 87, Loss: 1209.512451171875
Step 88, Loss: 1633.779296875
Step 89, Loss: 1832.9788818359375
Step 90, Loss: 1328.9354248046875
Step 91, Loss: 1182.6014404296875
Step 92, Loss: 1966.3428955078125
Step 93, Loss: 1258.536865234375
Step 94, Loss: 1890.312744140625
Step 95, Loss: 1252.579345703125
Step 96, Loss: 1126.8079833984375
Step 97, Loss: 1239.4892578125
Step 98, Loss: 1986.4447021484375
Step 99, Loss: 1606.381591796875
Step 100, Loss: 1837.0020751953125
Step 101, Loss: 1673.17578125
Step 102, Loss: 1055.4329833984375
Step 103, Loss: 1901.669189453125
Step 104, Loss: 1321.781494140625
Step 105, Loss: 1454.352294921875
Step 106, Loss: 1967.356201171875
Step 107, Loss: 1142.036376953125
Step 108, Loss: 1747.320556640625
Step 109, Loss: 1317.83447265625
Step 110, Loss: 1302.368408203125
Step 111, Loss: 2881.3740234375
Step 112, Loss: 1421.568603515625
Step 113, Loss: 1219.75
Step 114, Loss: 1560.2093505859375
Step 115, Loss: 1460.766845703125
Step 116, Loss: 1280.051513671875
Step 117, Loss: 1520.0809326171875
Step 118, Loss: 1670.777587890625
Step 119, Loss: 2165.53662109375
Step 120, Loss: 1461.3638916015625
Step 121, Loss: 1569.815185546875
Step 122, Loss: 1206.189697265625
Step 123, Loss: 1708.3341064453125
Step 124, Loss: 1029.6060791015625
Step 125, Loss: 2053.216796875
Step 126, Loss: 1506.0224609375
Step 127, Loss: 1525.813720703125
Step 128, Loss: 1918.3447265625
Step 129, Loss: 1060.704345703125
Step 130, Loss: 1065.752197265625
Step 131, Loss: 2078.143310546875
Step 132, Loss: 2007.9171142578125
Step 133, Loss: 1194.0960693359375
Step 134, Loss: 1399.7044677734375
Step 135, Loss: 1465.194580078125
Step 136, Loss: 1657.65478515625
Step 137, Loss: 1398.98779296875
Step 138, Loss: 1893.5042724609375
Step 139, Loss: 1590.546630859375
Step 140, Loss: 1150.5404052734375
Step 141, Loss: 1262.00341796875
Step 142, Loss: 1475.8515625
Step 143, Loss: 1419.84619140625
Step 144, Loss: 1982.8699951171875
Step 145, Loss: 1801.156494140625
Step 146, Loss: 1469.816650390625
Step 147, Loss: 1709.6822509765625
Step 148, Loss: 2077.054443359375
Step 149, Loss: 1249.806396484375
Step 150, Loss: 1342.635009765625
Step 151, Loss: 1223.06884765625
Step 152, Loss: 1386.104248046875
Step 153, Loss: 1990.470458984375
Step 154, Loss: 1389.048095703125
Step 155, Loss: 1889.6376953125
Step 156, Loss: 2174.4853515625
Step 157, Loss: 1561.5816650390625
Step 158, Loss: 1745.8248291015625
Step 159, Loss: 1040.3082275390625
Step 160, Loss: 1628.5341796875
Step 161, Loss: 1864.4423828125
Step 162, Loss: 1203.421142578125
Step 163, Loss: 1344.75390625
Step 164, Loss: 1595.09033203125
Step 165, Loss: 1238.3028564453125
Step 166, Loss: 1102.130859375
Step 167, Loss: 1766.8060302734375
Step 168, Loss: 1282.4505615234375
Step 169, Loss: 1372.6455078125
Step 170, Loss: 1287.114501953125
Step 171, Loss: 2015.166015625
Step 172, Loss: 1478.1759033203125
Step 173, Loss: 1425.34765625
Step 174, Loss: 1295.652587890625
Step 175, Loss: 1922.87060546875
Step 176, Loss: 1617.073974609375
Step 177, Loss: 1533.184326171875
Step 178, Loss: 1334.359619140625
Step 179, Loss: 1599.361328125
Step 180, Loss: 1004.377197265625
Step 181, Loss: 1357.1324462890625
Step 182, Loss: 1589.1905517578125
Step 183, Loss: 1594.9830322265625
Step 184, Loss: 1796.361328125
Step 185, Loss: 1063.1932373046875
Step 186, Loss: 1612.2103271484375
Step 187, Loss: 1179.7874755859375
Step 188, Loss: 1238.40380859375
Step 189, Loss: 1625.638916015625
Step 190, Loss: 1425.85595703125
Step 191, Loss: 1521.1871337890625
Step 192, Loss: 1244.6064453125
Step 193, Loss: 1631.3555908203125
Step 194, Loss: 1567.26611328125
Step 195, Loss: 1530.6263427734375
Step 196, Loss: 1284.45361328125
Step 197, Loss: 1493.089599609375
Step 198, Loss: 1071.772216796875
Step 199, Loss: 1706.98193359375
Step 200, Loss: 1632.34326171875
Step 201, Loss: 2103.3896484375
Step 202, Loss: 954.9825439453125
Step 203, Loss: 1323.804443359375
Step 204, Loss: 1528.763916015625
Step 205, Loss: 2031.8448486328125
Step 206, Loss: 1118.1773681640625
Step 207, Loss: 1317.7890625
Step 208, Loss: 1503.223388671875
Step 209, Loss: 1808.2283935546875
Step 210, Loss: 1703.98974609375
Step 211, Loss: 1049.0858154296875
Step 212, Loss: 1275.987060546875
Step 213, Loss: 1196.9306640625
Step 214, Loss: 1411.6932373046875
Step 215, Loss: 1246.766357421875
Step 216, Loss: 987.7606811523438
Step 217, Loss: 1669.424560546875
Step 218, Loss: 1411.1510009765625
Step 219, Loss: 1460.197998046875
Step 220, Loss: 1075.899169921875
Step 221, Loss: 1218.0361328125
Step 222, Loss: 1453.509033203125
Step 223, Loss: 1131.749267578125
Step 224, Loss: 1093.9444580078125
Step 225, Loss: 1427.131591796875
Step 226, Loss: 1485.5362548828125
Step 227, Loss: 1257.9765625
Step 228, Loss: 831.0469970703125
Step 229, Loss: 1500.1734619140625
Step 230, Loss: 883.95947265625
Step 231, Loss: 1360.126708984375
Step 232, Loss: 1227.45166015625
Step 233, Loss: 1387.870849609375
Step 234, Loss: 2018.3006591796875
Step 235, Loss: 1122.8701171875
Step 236, Loss: 1076.186279296875
Step 237, Loss: 1436.813720703125
Step 238, Loss: 1551.4228515625
Step 239, Loss: 1010.64208984375
Step 240, Loss: 1146.5865478515625
Step 241, Loss: 1784.942138671875
Step 242, Loss: 1605.85986328125
Step 243, Loss: 1865.7120361328125
Step 244, Loss: 1131.7679443359375
Step 245, Loss: 1586.39501953125
Step 246, Loss: 1053.21240234375
Step 247, Loss: 1830.83203125
Step 248, Loss: 1143.00146484375
Step 249, Loss: 1098.12890625
Step 250, Loss: 1623.798583984375
Step 251, Loss: 931.1444702148438
Step 252, Loss: 1191.5740966796875
Step 253, Loss: 1232.3299560546875
Step 254, Loss: 1454.0263671875
Step 255, Loss: 1281.9136962890625
Step 256, Loss: 1520.6201171875
Step 257, Loss: 1286.822998046875
Step 258, Loss: 1148.400146484375
Step 259, Loss: 1985.4564208984375
Step 260, Loss: 1180.51220703125
Step 261, Loss: 1593.327880859375
Step 262, Loss: 1154.2291259765625
Step 263, Loss: 1704.37353515625
Step 264, Loss: 1198.2415771484375
Step 265, Loss: 979.8606567382812
Step 266, Loss: 966.625732421875
Step 267, Loss: 1343.4063720703125
Step 268, Loss: 1128.93701171875
Step 269, Loss: 1148.60986328125
Step 270, Loss: 1303.980224609375
Step 271, Loss: 1707.3887939453125
Step 272, Loss: 1212.033203125
Step 273, Loss: 1485.076416015625
Step 274, Loss: 802.6442260742188
Step 275, Loss: 1658.125
Step 276, Loss: 1155.947265625
Step 277, Loss: 1341.27978515625
Step 278, Loss: 1107.769775390625
Step 279, Loss: 1477.561767578125
Step 280, Loss: 1592.630126953125
Step 281, Loss: 1608.114501953125
Step 282, Loss: 1213.718017578125
Step 283, Loss: 1298.125
Step 284, Loss: 1297.184326171875
Step 285, Loss: 1524.143798828125
Step 286, Loss: 2174.5380859375
Step 287, Loss: 1597.50244140625
Step 288, Loss: 959.2469482421875
Step 289, Loss: 1169.5859375
Step 290, Loss: 1442.810302734375
Step 291, Loss: 1034.099853515625
Step 292, Loss: 998.4970703125
Step 293, Loss: 1627.68798828125
Step 294, Loss: 1459.016845703125
Step 295, Loss: 1093.124267578125
Step 296, Loss: 1057.6783447265625
Step 297, Loss: 1432.466796875
Step 298, Loss: 1088.4395751953125
Step 299, Loss: 1693.2427978515625
Step 300, Loss: 1282.32421875
Step 301, Loss: 920.5098876953125
Step 302, Loss: 1827.253662109375
Step 303, Loss: 1689.7030029296875
Step 304, Loss: 1525.353515625
Step 305, Loss: 1677.1201171875
Step 306, Loss: 1784.0885009765625
Step 307, Loss: 1434.113037109375
Step 308, Loss: 1304.2841796875
Step 309, Loss: 1061.751953125
Step 310, Loss: 1253.636474609375
Step 311, Loss: 1144.22802734375
Step 312, Loss: 1889.0247802734375
Step 313, Loss: 1265.8465576171875
Step 314, Loss: 1193.4488525390625
Step 315, Loss: 1813.67578125
Step 316, Loss: 1387.3717041015625
Step 317, Loss: 1652.5
Step 318, Loss: 2130.800537109375
Step 319, Loss: 1272.6710205078125
Step 320, Loss: 1232.8660888671875
Step 321, Loss: 1048.517822265625
Step 322, Loss: 1738.4168701171875
Step 323, Loss: 994.63671875
Step 324, Loss: 1520.3111572265625
Step 325, Loss: 1712.8887939453125
Step 326, Loss: 1547.9693603515625
Step 327, Loss: 1156.1005859375
Step 328, Loss: 1533.877197265625
Step 329, Loss: 1518.6517333984375
Step 330, Loss: 1046.860595703125
Step 331, Loss: 1764.745361328125
Step 332, Loss: 1499.9609375
Step 333, Loss: 1422.4716796875
Step 334, Loss: 1091.22998046875
Step 335, Loss: 873.0045166015625
Step 336, Loss: 1054.060302734375
Step 337, Loss: 1297.4365234375
Step 338, Loss: 1910.373291015625
Step 339, Loss: 1235.56494140625
Step 340, Loss: 1540.74609375
Step 341, Loss: 1384.647705078125
Step 342, Loss: 1600.019775390625
Step 343, Loss: 1365.9439697265625
Step 344, Loss: 1584.5767822265625
Step 345, Loss: 1595.54736328125
Step 346, Loss: 1569.956298828125
Step 347, Loss: 1193.80712890625
Step 348, Loss: 953.3492431640625
Step 349, Loss: 1395.6644287109375
Step 350, Loss: 1400.4874267578125
Step 351, Loss: 1451.081298828125
Step 352, Loss: 1007.0923461914062
Step 353, Loss: 1332.080322265625
Step 354, Loss: 1357.7301025390625
Step 355, Loss: 953.5802001953125
Step 356, Loss: 1283.988525390625
Step 357, Loss: 1157.7041015625
Step 358, Loss: 1377.642822265625
Step 359, Loss: 1074.5054931640625
Step 360, Loss: 1783.91455078125
Step 361, Loss: 1339.19677734375
Step 362, Loss: 1765.009765625
Step 363, Loss: 1154.8880615234375
Step 364, Loss: 1303.498291015625
Step 365, Loss: 1372.677734375
Step 366, Loss: 1622.110107421875
Step 367, Loss: 1271.9083251953125
Step 368, Loss: 1147.957763671875
Step 369, Loss: 1974.82080078125
Step 370, Loss: 1554.877197265625
Step 371, Loss: 1458.959228515625
Step 372, Loss: 1624.116455078125
Step 373, Loss: 1260.9439697265625
Step 374, Loss: 1553.80419921875
Step 375, Loss: 1050.49462890625
Step 376, Loss: 1284.7432861328125
Step 377, Loss: 1408.9525146484375
Step 378, Loss: 1273.099609375
Step 379, Loss: 1459.145751953125
Step 380, Loss: 1198.253173828125
Step 381, Loss: 1334.21826171875
Step 382, Loss: 1057.520263671875
Step 383, Loss: 1203.5245361328125
Step 384, Loss: 1140.5009765625
Step 385, Loss: 1093.99462890625
Step 386, Loss: 1212.078857421875
Step 387, Loss: 1388.4498291015625
Step 388, Loss: 1805.5279541015625
Step 389, Loss: 1513.4420166015625
Step 390, Loss: 1375.2115478515625
Step 391, Loss: 1075.574951171875
Step 392, Loss: 1794.258544921875
Step 393, Loss: 1790.400146484375
Step 394, Loss: 1223.4815673828125
Step 395, Loss: 1891.246826171875
Step 396, Loss: 1332.2364501953125
Step 397, Loss: 1290.03759765625
Step 398, Loss: 1758.0888671875
Step 399, Loss: 1058.5726318359375
Step 400, Loss: 912.8321533203125
Step 401, Loss: 1143.2447509765625
Step 402, Loss: 1383.248291015625
Step 403, Loss: 1202.1904296875
Step 404, Loss: 1598.3101806640625
Step 405, Loss: 1495.3203125
Step 406, Loss: 1449.4405517578125
Step 407, Loss: 1325.498779296875
Step 408, Loss: 1168.171875
Step 409, Loss: 830.4091796875
Step 410, Loss: 1674.7080078125
Step 411, Loss: 1399.978759765625
Step 412, Loss: 1705.4149169921875
Step 413, Loss: 2014.07421875
Step 414, Loss: 1577.784912109375
Step 415, Loss: 1314.164306640625
Step 416, Loss: 1251.5391845703125
Step 417, Loss: 1779.771728515625
Step 418, Loss: 1411.2890625
Step 419, Loss: 898.7115478515625
Step 420, Loss: 1513.54833984375
Step 421, Loss: 1109.6455078125
Step 422, Loss: 1550.4925537109375
Step 423, Loss: 1210.365966796875
Step 424, Loss: 1833.0389404296875
Step 425, Loss: 1037.49365234375
Step 426, Loss: 1230.4632568359375
Step 427, Loss: 1338.976318359375
Step 428, Loss: 1424.70947265625
Step 429, Loss: 1671.7451171875
Step 430, Loss: 1829.2230224609375
Step 431, Loss: 1668.832763671875
Step 432, Loss: 1546.8280029296875
Step 433, Loss: 1254.328369140625
Step 434, Loss: 1617.667724609375
Step 435, Loss: 1311.4676513671875
Step 436, Loss: 1078.8048095703125
Step 437, Loss: 1459.23828125
Step 438, Loss: 1391.421630859375
Step 439, Loss: 1024.3675537109375
Step 440, Loss: 1114.6915283203125
Step 441, Loss: 1341.411865234375
Step 442, Loss: 990.5977783203125
Step 443, Loss: 1176.3507080078125
Step 444, Loss: 1338.45361328125
Step 445, Loss: 1037.4556884765625
Step 446, Loss: 1386.7548828125
Step 447, Loss: 1723.485595703125
Step 448, Loss: 2248.262451171875
Step 449, Loss: 1507.2198486328125
Step 450, Loss: 1902.50732421875
Step 451, Loss: 1609.2783203125
Step 452, Loss: 883.104248046875
Step 453, Loss: 933.4012451171875
Step 454, Loss: 1770.635009765625
Step 455, Loss: 1091.0963134765625
Step 456, Loss: 1338.765625
Step 457, Loss: 1500.919677734375
Step 458, Loss: 1115.61279296875
Step 459, Loss: 1865.16357421875
Step 460, Loss: 1275.708740234375
Step 461, Loss: 1856.6123046875
Step 462, Loss: 1179.163330078125
Step 463, Loss: 785.0582275390625
Step 464, Loss: 1065.76708984375
Step 465, Loss: 1140.35302734375
Step 466, Loss: 1262.082275390625
Step 467, Loss: 1484.341064453125
Step 468, Loss: 1126.927001953125
Step 469, Loss: 1812.801513671875
Step 470, Loss: 1583.437255859375
Step 471, Loss: 1389.092529296875
Step 472, Loss: 973.5128173828125
Step 473, Loss: 2022.6402587890625
Step 474, Loss: 1326.77734375
Step 475, Loss: 1460.857421875
Step 476, Loss: 1146.515869140625
Step 477, Loss: 975.498291015625
Step 478, Loss: 900.0221557617188
Step 479, Loss: 1166.374267578125
Step 480, Loss: 1737.30126953125
Step 481, Loss: 1062.071044921875
Step 482, Loss: 1642.854248046875
Step 483, Loss: 1408.7431640625
Step 484, Loss: 1588.709716796875
Step 485, Loss: 895.8580322265625
Step 486, Loss: 1766.7691650390625
Step 487, Loss: 863.95654296875
Step 488, Loss: 1673.743408203125
Step 489, Loss: 1306.481201171875
Step 490, Loss: 1376.691650390625
Step 491, Loss: 1305.647705078125
Step 492, Loss: 1177.880859375
Step 493, Loss: 1562.1954345703125
Step 494, Loss: 2219.90478515625
Step 495, Loss: 911.30419921875
Step 496, Loss: 1514.813232421875
Step 497, Loss: 1346.8331298828125
Step 498, Loss: 991.4085693359375
Step 499, Loss: 977.5787353515625
Step 500, Loss: 1657.347412109375
Step 501, Loss: 1287.251953125
Step 502, Loss: 743.5289306640625
Step 503, Loss: 1239.245361328125
Step 504, Loss: 1049.7587890625
Step 505, Loss: 1977.7630615234375
Step 506, Loss: 1106.8863525390625
Step 507, Loss: 1745.14404296875
Step 508, Loss: 1097.395751953125
Step 509, Loss: 1811.0892333984375
Step 510, Loss: 1231.067626953125
Step 511, Loss: 1381.961669921875
Step 512, Loss: 1410.78662109375
Step 513, Loss: 1118.81787109375
Step 514, Loss: 1412.028076171875
Step 515, Loss: 1493.8941650390625
Step 516, Loss: 1394.9820556640625
Step 517, Loss: 1060.9759521484375
Step 518, Loss: 960.06982421875
Step 519, Loss: 1316.806884765625
Step 520, Loss: 1471.657958984375
Step 521, Loss: 1293.024658203125
Step 522, Loss: 1042.2998046875
Step 523, Loss: 1133.102294921875
Step 524, Loss: 1363.699951171875
Step 525, Loss: 1608.5966796875
Step 526, Loss: 1173.14794921875
Step 527, Loss: 1066.2161865234375
Step 528, Loss: 1780.4852294921875
Step 529, Loss: 1444.552978515625
Step 530, Loss: 942.1851806640625
Step 531, Loss: 1389.2520751953125
Step 532, Loss: 1434.215087890625
Step 533, Loss: 1865.0098876953125
Step 534, Loss: 1295.721435546875
Step 535, Loss: 1056.942138671875
Step 536, Loss: 1707.695556640625
Step 537, Loss: 1559.7491455078125
Step 538, Loss: 1124.744873046875
Step 539, Loss: 1057.86767578125
Step 540, Loss: 1226.157470703125
Step 541, Loss: 1371.0052490234375
Step 542, Loss: 1402.97998046875
Step 543, Loss: 1217.48876953125
Step 544, Loss: 1207.627685546875
Step 545, Loss: 1065.9107666015625
Step 546, Loss: 1155.632080078125
Step 547, Loss: 1536.2127685546875
Step 548, Loss: 1335.600341796875
Step 549, Loss: 1116.26904296875
Step 550, Loss: 1721.7852783203125
Step 551, Loss: 1162.0531005859375
Step 552, Loss: 1553.541015625
Step 553, Loss: 1483.7789306640625
Step 554, Loss: 1179.0186767578125
Step 555, Loss: 1394.5513916015625
Step 556, Loss: 1367.4727783203125
Step 557, Loss: 1133.940673828125
Step 558, Loss: 945.0025024414062
Step 559, Loss: 1859.411865234375
Step 560, Loss: 1189.91650390625
Step 561, Loss: 1205.9249267578125
Step 562, Loss: 1007.8544311523438
Step 563, Loss: 1241.7750244140625
Step 564, Loss: 1296.9080810546875
Step 565, Loss: 1383.333740234375
Step 566, Loss: 2233.53515625
Step 567, Loss: 1587.458251953125
Step 568, Loss: 1825.3814697265625
Step 569, Loss: 1371.637451171875
Step 570, Loss: 1573.197265625
Step 571, Loss: 1047.68017578125
Step 572, Loss: 907.6341552734375
Step 573, Loss: 1212.4932861328125
Step 574, Loss: 2112.6806640625
Step 575, Loss: 950.9673461914062
Step 576, Loss: 1824.330810546875
Step 577, Loss: 1359.4169921875
Step 578, Loss: 1560.9697265625
Step 579, Loss: 1049.075439453125
Step 580, Loss: 1353.3956298828125
Step 581, Loss: 1516.19921875
Step 582, Loss: 1597.74951171875
Step 583, Loss: 1127.2060546875
Step 584, Loss: 1558.127685546875
Step 585, Loss: 1831.9423828125
Step 586, Loss: 1555.7625732421875
Step 587, Loss: 1547.912109375
Step 588, Loss: 1655.69140625
Step 589, Loss: 1040.394287109375
Step 590, Loss: 1128.268310546875
Step 591, Loss: 1267.056884765625
Step 592, Loss: 1277.0677490234375
Step 593, Loss: 1127.994384765625
Step 594, Loss: 1245.7625732421875
Step 595, Loss: 1278.789306640625
Step 596, Loss: 2345.984619140625
Step 597, Loss: 2026.3302001953125
Step 598, Loss: 1655.1656494140625
Step 599, Loss: 1053.072265625
Step 600, Loss: 1770.6749267578125
Step 601, Loss: 992.5887451171875
Step 602, Loss: 1830.2685546875
Step 603, Loss: 1374.114990234375
Step 604, Loss: 1073.7449951171875
Step 605, Loss: 1183.1143798828125
Step 606, Loss: 1291.30322265625
Step 607, Loss: 1838.1009521484375
Step 608, Loss: 1740.20556640625
Step 609, Loss: 937.88232421875
Step 610, Loss: 1357.314697265625
Step 611, Loss: 1289.94921875
Step 612, Loss: 1513.4420166015625
Step 613, Loss: 1593.559814453125
Step 614, Loss: 1200.7783203125
Step 615, Loss: 1296.7430419921875
Step 616, Loss: 1384.037109375
Step 617, Loss: 995.932861328125
Step 618, Loss: 1321.1334228515625
Step 619, Loss: 1980.3076171875
Step 620, Loss: 1491.490966796875
Step 621, Loss: 1263.7740478515625
Step 622, Loss: 1492.602783203125
Step 623, Loss: 1138.0784912109375
Step 624, Loss: 1257.24072265625
Step 625, Loss: 1708.15283203125
In [ ]:
# torch.save(cm_unet.state_dict(), 'mbcd_model.pth')

И в последний раз сэмплируем¶

Важно: теперь у нас появляется возможно сэмплировать детерминистично с помощью оригинального солверва DDIM за 4 шага. Так что возвращаем сэмплирование исходным pipe-ом.

Ниже прикрепляем референс и напомним, что у вас картинки могут отличаться и быть чуть хуже/лучше. img

In [ ]:
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'multi-cd'

guidance_scale = 1

for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(1)

    images = pipe(
        prompt=prompt,
        num_inference_steps=4,
        num_images_per_prompt=4,
        generator=generator,
        guidance_scale=guidance_scale,
    ).images

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Возвращаемся в колаб к остальным адаптерам¶

In [ ]:
# Загрузка предыдущих адаптеров
load = True
if load:
    cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")

    cm_unet.enable_gradient_checkpointing()

    cm_unet.load_state_dict(torch.load('/content/drive/MyDrive/cv_model/cm_model.pth'))

    # Добавляем адаптер "cd" и загружаем его параметры
    cm_unet.add_adapter("cd", lora_config)
    cm_unet.set_adapter("cd")
    cm_unet.load_state_dict(torch.load('/content/drive/MyDrive/cv_model/cd_model.pth'))

    # Добавляем адаптер "multi-cd" и загружаем его параметры
    cm_unet.add_adapter("multi-cd", lora_config)
    cm_unet.set_adapter("multi-cd")
    cm_unet.load_state_dict(torch.load('/content/drive/MyDrive/cv_model/mbcd_model.pth'))

Задание №8¶

Все, что осталось сделать - это загрузить ваши обученные модельки на huggingface_hub. Это очень популярный и удобный способ для хранения моделей, которые легко можно загружать и подставлять в модель. Другими словами GitHub для моделей и датасетов.

  1. Создайте аккаунт на huggingface.co

  2. Получите свой HF токен, который можно получить здесь: https://huggingface.co/settings/tokens

  3. Создайте репозиторий для ваших моделями https://huggingface.co/new

Важно: перед отправкой нотбука на проверку, не забудьте удалить свой hf токен!

In [19]:
cm_unet.push_to_hub(
    "eteron/cv_week_final_model", # "<username>/<repo-name>"
    token=my_token
)
README.md:   0%|          | 0.00/31.0 [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/538M [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]
Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]
adapter_model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]
Out[19]:
CommitInfo(commit_url='https://huggingface.co/eteron/cv_week_final_model/commit/7415ad7aa7101a82e3c0ab56c3bdc325330e8c80', commit_message='Upload model', commit_description='', oid='7415ad7aa7101a82e3c0ab56c3bdc325330e8c80', pr_url=None, repo_url=RepoUrl('https://huggingface.co/eteron/cv_week_final_model', endpoint='https://huggingface.co', repo_type='model', repo_id='eteron/cv_week_final_model'), pr_revision=None, pr_num=None)

Пример, как должен выглядеть результат выполнения команды: https://huggingface.co/dbaranchuk/cv-week-final-task-example

Давайте проверим, что загрузка модели корректно работает.

In [20]:
from peft import PeftModel

loaded_cm_unet = PeftModel.from_pretrained(
    unet,
    "eteron/cv_week_final_model",
    token=my_token,
    subfolder='multi-cd',
    adapter_name="multi-cd",
)
multi-cd/adapter_config.json:   0%|          | 0.00/945 [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]
In [23]:
pipe.unet = loaded_cm_unet.eval().to(torch.float16)
assert loaded_cm_unet.active_adapter == 'multi-cd'

guidance_scale = 1

for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(1)

    images = pipe(
        prompt=prompt,
        num_inference_steps=4,
        num_images_per_prompt=4,
        generator=generator,
        guidance_scale=guidance_scale,
    ).images

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

На этом все! Ура!

No description has been provided for this image